wirehead 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wirehead-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 neuroneural/wirehead
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,4 @@
1
+ Metadata-Version: 2.1
2
+ Name: wirehead
3
+ Version: 0.1.0
4
+ License-File: LICENSE
@@ -0,0 +1,201 @@
1
+ # Wirehead #
2
+
3
+ Caching system for scaling of synthetic data generators using MongoDB
4
+
5
+ ![wirehead-diagram](assets/training-diagram.png)
6
+
7
+ ---
8
+
9
+ # I. Installation
10
+
11
+ ## 1. MongoDB Setup (For Development/Testing Only)
12
+
13
+ - [Ubuntu Setup](#a-quick-mongodb-setup-ubuntu)
14
+ - [macOS Setup](#b-quick-mongodb-setup-macos)
15
+
16
+ **Important Note:** The following instructions are for development and testing purposes only. For production deployments, please refer to the [official MongoDB documentation](https://www.mongodb.com/docs/manual/administration/install-community/) for secure and proper installation guidelines.
17
+
18
+ #### a. Quick MongoDB Setup ([Ubuntu](https://www.mongodb.com/docs/manual/tutorial/install-mongodb-on-ubuntu/)):
19
+
20
+ ```bash
21
+ sudo apt-get install gnupg curl
22
+ curl -fsSL https://www.mongodb.org/static/pgp/server-7.0.asc | \
23
+ sudo gpg -o /usr/share/keyrings/mongodb-server-7.0.gpg \
24
+ --dearmor
25
+ echo "deb [ arch=amd64,arm64 signed-by=/usr/share/keyrings/mongodb-server-7.0.gpg ] https://repo.mongodb.org/apt/ubuntu jammy/mongodb-org/7.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-7.0.list
26
+ sudo apt-get update
27
+ sudo apt-get install -y mongodb-org
28
+ ```
29
+
30
+ ```bash
31
+ # Run MongoDB
32
+ sudo systemctl start mongod
33
+ ```
34
+
35
+ ```bash
36
+ # Stop MongoDB
37
+ sudo systemctl stop mongod
38
+ ```
39
+
40
+ #### b. Quick MongoDB Setup ([MacOS](https://www.mongodb.com/docs/manual/tutorial/install-mongodb-on-os-x/)):
41
+
42
+ ```bash
43
+ brew tap mongodb/brew
44
+ brew update
45
+ brew install mongodb-community@7.0
46
+ ```
47
+
48
+ ```bash
49
+ # Run MongoDB
50
+ brew services start mongodb-community@7.0
51
+ ```
52
+
53
+ ```bash
54
+ # Stop MongoDB
55
+ brew services stop mongodb-community@7.0
56
+ ```
57
+
58
+ **Warning:** These instructions are simplified for ease of setup in a development environment. For production use, ensure proper security measures, authentication, and follow best practices as outlined in the official MongoDB documentation.
59
+ Installing and deploying MongoDB
60
+
61
+ ## 2. Create virtual environment
62
+ ```bash
63
+ # Note:
64
+ # python version doesn't necessarily have to be 3.10
65
+ # but this gives better support for some generation pipelines
66
+
67
+ # Conda
68
+ conda create -n wirehead python=3.10
69
+ conda activate wirehead
70
+
71
+ # venv
72
+ python3.10 -m venv wirehead
73
+ source wirehead/bin/activate
74
+ ```
75
+
76
+ ## 3. Install wirehead:
77
+ ```bash
78
+ git clone git@github.com:neuroneural/wirehead.git
79
+ cd wirehead
80
+ pip install -e .
81
+ ```
82
+
83
+ ## 4. Run the test
84
+ ```bash
85
+ cd examples/unit
86
+ chmod +x test.sh
87
+ ./test.sh
88
+ ```
89
+
90
+ # II. Usage
91
+
92
+ See examples/unit for a minimal example
93
+
94
+ ## 1. Manager
95
+ ```python
96
+ from wirehead import WireheadManager
97
+
98
+ if __name__ == "__main__":
99
+ wirehead_runtime = WireheadManager(config_path="config.yaml")
100
+ wirehead_runtime.run_manager()
101
+ ```
102
+
103
+ ## 2. Generator
104
+
105
+ ```python
106
+ import numpy as np
107
+ from wirehead import WireheadGenerator
108
+
109
+ def create_generator():
110
+ while True:
111
+ img = np.random.rand(256,256,256)
112
+ lab = np.random.rand(256,256,256)
113
+ yield (img, lab)
114
+
115
+ if __name__ == "__main__":
116
+ brain_generator = create_generator()
117
+ wirehead_runtime = WireheadGenerator(
118
+ generator = brain_generator,
119
+ config_path = "config.yaml"
120
+ )
121
+ wirehead_runtime.run_generator()
122
+ ```
123
+
124
+ ## 3. Dataset
125
+ ```python
126
+ import torch
127
+ from wirehead import MongoheadDataset
128
+
129
+ dataset = MongoheadDataset(config_path = "config.yaml")
130
+
131
+ idx = [0]
132
+ data = dataset[idx]
133
+ sample, label = data[0]['input'], data[0]['label']
134
+ ```
135
+
136
+ # III. Config guide
137
+
138
+ All wirehead configs live inside yaml files, and must be specified when declaring wirehead manager, generator and dataset objects. For the system to work, all components must use the __same__ configs.
139
+
140
+ ## 1. Basic configs:
141
+ ```yaml
142
+ MONGOHOST -- IP address or hostname for machine running MongoDB instance
143
+ DBNAME -- MongoDB database name
144
+ PORT -- Port for MongoDB instance. Defaults to 27017
145
+ SWAP_CAP -- Size cap for read and write collections. bigger means bigger cache, and less frequent swaps. The total memory used by wirehead can be calculated with:
146
+ SWAP_CAP * SIZE OF YIELDED TUPLE * 2
147
+ ```
148
+
149
+ ## 2. Advanced configs:
150
+ ```yaml
151
+ SAMPLE -- Array of strings denoting name of samples in data tuple.
152
+ WRITE_COLLECTION -- Name of write collection (generators push to this)
153
+ READ_COLLECTION -- Name of read collection (dataset reads from this)
154
+ COUNTER_COLLECTION -- Name of counter collection for manager metrics
155
+ TEMP_COLLECTION -- Name of temporary collection used for moving data during swap
156
+ CHUNKSIZE -- Number of megabytes used for chunking data
157
+ ```
158
+
159
+ ---
160
+
161
+ # IV. Generator example
162
+
163
+ See a simple example in [examples/unit/generator.py](examples/unit/generator.py) or a Synthseg example in [examples/synthseg/generator.py](examples/synthseg/generator.py)
164
+
165
+ Wirehead's [WireheadGenerator](https://github.com/neuroneural/wirehead/blob/main/wirehead/generator.py) object takes in a generator, which is a python generator function. This function yields a tuple containing numpy arrays. The number of samples in this tuple should match the number of strings specified in SAMPLE in config.yaml
166
+
167
+ ## 1. Set SAMPLE in "config.yaml" (note the number of keys)
168
+ ```yaml
169
+ SAMPLE: ["a", "b"]
170
+ ```
171
+
172
+ ## 2. Create a generator function, which yields the same number of objects
173
+ ```python
174
+ def create_generator():
175
+ while True:
176
+ a = np.random.rand(256,256,256)
177
+ b = np.random.rand(256,256,256)
178
+ yield (a, b)
179
+ ```
180
+
181
+ ## 3. Insert config file path and generator function into WireheadGenerator
182
+ ```python
183
+ generator = create_generator()
184
+ runtime = WireheadGenerator(
185
+ generator = generator,
186
+ config_path = "config.yaml"
187
+ )
188
+ ```
189
+
190
+ ## 4. Press play
191
+ ```python
192
+ runtime.run_generator() # runs an infinite loop
193
+ ```
194
+
195
+ ---
196
+
197
+ # Citation/Contact
198
+
199
+ This code is under [MIT](https://github.com/neuroneural/wirehead/blob/main/LICENSE) licensing
200
+
201
+ If you have any questions specific to the Wirehead pipeline, please raise an issue or contact us at mdoan4@gsu.edu
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,21 @@
1
"""Packaging script for the wirehead distribution."""
from setuptools import find_packages, setup

# Runtime dependencies; keep in sync with wirehead.egg-info/requires.txt.
REQUIREMENTS = [
    'pymongo',
    'torch',
    'numpy',
    'PyYaml',
]

setup(
    name='wirehead',
    version='0.1.0',
    packages=find_packages(),
    install_requires=REQUIREMENTS,
    # No command-line entry points are exposed yet.
    entry_points={
        'console_scripts': [],
    },
)
@@ -0,0 +1,4 @@
1
+ from wirehead.dataset import MongoheadDataset
2
+ from wirehead.dataset import MongoTupleheadDataset
3
+ from wirehead.manager import WireheadManager
4
+ from wirehead.generator import WireheadGenerator
@@ -0,0 +1,253 @@
1
+ """ Wirehead dataset class for MongoDB """
2
+
3
import io
import os
import sys
import time

import torch
import yaml
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure
from torch.utils.data import Dataset
12
+
13
+
14
def unit_interval_normalize(img):
    """Rescale *img* linearly so its values span [0, 1].

    Works on anything exposing ``min``/``max`` and elementwise arithmetic
    (numpy arrays, torch tensors). A constant input divides by zero.
    """
    lo, hi = img.min(), img.max()
    return (img - lo) / (hi - lo)
18
+
19
+
20
def quantile_normalize(img, qmin=0.01, qmax=0.99):
    """Quantile normalization: rescale *img* so its ``qmin`` quantile maps
    to 0 and its ``qmax`` quantile maps to 1.

    Values below/above those quantiles fall outside [0, 1].

    :param img: torch tensor (anything with a ``quantile`` method)
    :param qmin: lower quantile, default 0.01
    :param qmax: upper quantile, default 0.99
    """
    # Hoisted: the original computed img.quantile(qmin) twice.
    lo = img.quantile(qmin)
    hi = img.quantile(qmax)
    return (img - lo) / (hi - lo)
24
+
25
+
26
def binary_to_tensor(tensor_binary):
    """Deserialize *tensor_binary* (bytes produced by ``torch.save``)
    back into a torch tensor."""
    with io.BytesIO(tensor_binary) as buffer:
        return torch.load(buffer)
31
+
32
+
33
class MongoheadDataset(Dataset):
    """
    A dataset for fetching batches of records from a MongoDB.

    Each sample is stored as several binary chunk documents sharing the
    same id field (one chunk series per "kind", e.g. data vs. label);
    __getitem__ reassembles the chunks and deserializes them into tensors.
    """

    def __init__(self,
                 config_path="",
                 timeout=60,
                 collection=None,
                 sample=("data", "label"),
                 transform=binary_to_tensor,
                 normalize=lambda x: x,
                 id="id",
                 keeptrying=True):
        """Constructor
        :param config_path: path to wirehead config .yaml file; when given
            and existing, connection settings are loaded from it and
            `collection`/`sample` arguments are ignored
        :param timeout: stored on the instance; note that wait_for_data
            uses its own default timeout, not this value
        :param collection: pymongo collection to be used (no-config mode)
        :param sample: a pair of fields to be fetched as `input` and `label`
            , e.g. (`T1`, `label104`)
        :param transform: a function applied to each record's serialized
            bytes (default deserializes into a torch tensor)
        :param normalize: a function applied to the transformed input tensor
        :param id: the field to be used as an index. The `indices` are
            values of this field
        :param keeptrying: whether to attempt a refetch if first attempt fails
        :returns: an object of MongoheadDataset class
        """
        self.id = id  # NOTE: parameter shadows builtin id(); kept for API compatibility
        self.normalize = normalize
        self.transform = transform
        self.keeptrying = keeptrying  # retries if fetch fails
        # find() projection: only the chunk bookkeeping fields are fetched
        self.fields = {"id": 1, "chunk": 1, "kind": 1, "chunk_id": 1}
        self.timeout = timeout

        if config_path != "" and os.path.exists(config_path):
            self.load_from_yaml(config_path)

        else:
            self.collection = collection
            self.sample = sample

        self.indices = self.get_indeces()

    def load_from_yaml(self, config_path):
        """
        Loads config options from config_path and connects to MongoDB.

        Reads DBNAME, MONGOHOST, PORT (default 27017), READ_COLLECTION and
        SAMPLE, then blocks in wait_for_data until a swap has happened.
        """
        print("Dataset: config loaded from " + config_path)
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        dbname = config.get('DBNAME')
        mongohost = config.get('MONGOHOST')
        port = config.get('PORT') if config.get('PORT') is not None else 27017
        client = MongoClient("mongodb://" + mongohost + ":" + str(port))

        db = client[dbname]
        self.wait_for_data(db)
        read_collection = config.get("READ_COLLECTION")
        self.collection = db[read_collection]
        self.sample = tuple(config.get("SAMPLE"))

    def wait_for_data(self, db, timeout=600, check_interval=10):
        """
        Prevents data object from reading before data is ready.

        Polls the "status" collection every `check_interval` seconds until
        the latest status document has swapped=True.

        Raises:
            TimeoutError: If the wait time exceeds the timeout
        """
        status_collection = db["status"]
        start_time = time.time()

        def check_status():
            # True once the manager has recorded a completed swap.
            try:
                latest_status = status_collection.find_one(sort=[("_id", -1)])
                if latest_status is None:
                    print("Dataset: Database is empty, waiting for data...")
                    return False
                if not latest_status.get("swapped", False):
                    print("Dataset: Swap has not happened, waiting...")
                    return False
                return True
            except (ConnectionFailure, OperationFailure) as e:
                print(f"Dataset: Error accessing database: {e}")
                return False

        while time.time() - start_time < timeout:
            if check_status():
                print("Dataset: Data is ready")
                return
            time.sleep(check_interval)

        raise TimeoutError("Dataset: Waited too long for data to be ready")

    def get_indeces(self):
        """
        Retrieve the index array of samples in read collection.

        Assumes ids are contiguous from 0 (the manager's integrity check
        enforces this), so range(0, max_id + 1) covers every sample.
        """
        last_post = self.collection['bin'].find_one(sort=[(self.id, -1)])

        if last_post is None:
            print("Empty collection, exiting")
            sys.exit()
        num_examples = int(last_post[self.id] + 1)
        return range(num_examples)

    def __len__(self):
        # Number of distinct sample ids available in the read collection.
        return len(self.indices)

    def make_serial(self, samples_for_id, kind):
        """
        Converts collection chunks into a contiguous byte sequence.

        Filters `samples_for_id` down to the given kind, orders the chunks
        by chunk_id and concatenates their binary payloads.
        """
        return b"".join([
            sample["chunk"] for sample in sorted(
                (sample for sample in samples_for_id if sample["kind"] == kind),
                key=lambda x: x["chunk_id"],
            )
        ])

    def retry_on_eof_error(retry_count, verbose=False):
        """
        Decorator factory: retry a fetch method when a read races with a
        collection swap (EOFError from truncated chunks, OperationFailure
        from mongo).

        Evaluated at class-definition time, so it is written as a plain
        function (no self) rather than a method.
        """

        def decorator(func):

            def wrapper(self, batch, *args, **kwargs):
                myException = Exception  # Default Exception if not overwritten
                for attempt in range(retry_count):
                    try:
                        return func(self, batch, *args, **kwargs)
                    except (
                            EOFError,
                            OperationFailure,
                    ) as exception:  # Specifically catching EOFError
                        if self.keeptrying:
                            if verbose:
                                print(
                                    f"EOFError caught. Retrying {attempt+1}/{retry_count}"
                                )
                            time.sleep(1)
                            continue
                        else:
                            raise exception
                raise myException("Failed after multiple retries.")

            return wrapper

        return decorator

    @retry_on_eof_error(retry_count=3, verbose=True)
    def __getitem__(self, batch):
        """
        Fetch all samples for ids in the batch and where 'kind' is either
        data or label as specified by the sample parameter.

        :param batch: iterable of positions into self.indices
        :returns: dict mapping each batch position to
            {"input": tensor, "label": tensor}
        """
        # A single query fetches every chunk for all requested ids at once.
        samples = list(self.collection["bin"].find(
            {
                self.id: {
                    "$in": [self.indices[_] for _ in batch]
                },
                "kind": {
                    "$in": self.sample
                },
            },
            self.fields,
        ))
        results = {}
        for id in batch:
            # Separate samples for this id
            samples_for_id = [
                sample for sample in samples if sample[self.id] == self.indices[id]
            ]

            # Separate processing for each 'kind'
            data = self.make_serial(samples_for_id, self.sample[0])
            label = self.make_serial(samples_for_id, self.sample[1])

            # Add to results; the input is cast to float and then normalized
            results[id] = {
                "input": self.normalize(self.transform(data).float()),
                "label": self.transform(label),
            }
        return results
214
+
215
class MongoTupleheadDataset(MongoheadDataset):
    """
    A dataset for fetching batches of records from a MongoDB.
    Returns a list of (input, label) tuples instead of a dict keyed by id.
    """

    def __init__(self, *args, **kwargs):
        # Same construction as MongoheadDataset; only __getitem__ differs.
        super().__init__(*args, **kwargs)

    @MongoheadDataset.retry_on_eof_error(retry_count=3, verbose=True)
    def __getitem__(self, batch):
        """
        Fetch all samples for ids in the batch and return as tuples.

        :param batch: iterable of positions into self.indices
        :returns: list of (input, label) tensor tuples, one per batch entry
        """
        # A single query fetches every chunk for all requested ids,
        # restricted to the kinds named in self.sample.
        samples = list(self.collection["bin"].find(
            {
                self.id: {
                    "$in": [self.indices[_] for _ in batch]
                },
                "kind": {
                    "$in": self.sample
                },
            },
            self.fields,
        ))
        results = []
        for id in batch:
            samples_for_id = [
                sample for sample in samples if sample[self.id] == self.indices[id]
            ]

            # Reassemble each kind's chunks into one contiguous byte string
            data = self.make_serial(samples_for_id, self.sample[0])
            label = self.make_serial(samples_for_id, self.sample[1])

            # Deserialize; the input is additionally cast to float/normalized
            data = self.normalize(self.transform(data).float())
            label = self.transform(label)

            results.append((data, label))
        return results
@@ -0,0 +1,127 @@
1
+ """ Wirehead Generator Class """
2
+
3
+ import io
4
+ import os
5
+ import time
6
+ import yaml
7
+ import bson
8
+ import torch
9
+ from pymongo import MongoClient, ReturnDocument
10
+
11
+
12
class WireheadGenerator():
    """
    Wirehead runtime class that wraps a user-supplied synthetic data
    generator and streams its samples into the MongoDB write collection
    as serialized binary chunks.
    """

    def __init__(self, generator, config_path, n_samples = 1000):
        """
        :param generator: python generator yielding tuples of numpy arrays,
            one array per name in the SAMPLE config entry
        :param config_path: path to the wirehead yaml config
        :param n_samples: number of samples pushed by run_generator()
        """
        # NOTE(review): on a bad config this prints and returns, leaving a
        # half-initialized object; later method calls raise AttributeError.
        if config_path is None or os.path.exists(config_path) is False:
            print("No valid config specified, exiting")
            return
        self.load_from_yaml(config_path)
        self.generator = generator
        self.n_samples = n_samples

    def load_from_yaml(self, config_path):
        """ Loads generator configs (db connection, swap cap, sample names,
        chunk size, collection names) from config_path """
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        dbname = config.get('DBNAME')
        mongohost = config.get('MONGOHOST')
        port = config.get('PORT') if config.get('PORT') is not None else 27017
        client = MongoClient("mongodb://" + mongohost + ":" + str(port))

        self.db = client[dbname]
        self.swap_cap = config.get('SWAP_CAP')
        self.sample = tuple(config.get("SAMPLE"))
        self.chunksize = config.get("CHUNKSIZE")
        # ".bin" addresses the chunk sub-collection of the write collection
        self.collectionw = config.get("WRITE_COLLECTION") + ".bin"
        self.collectionc = config.get("COUNTER_COLLECTION")

    def chunkify(self, data, index):
        """
        Converts a tuple of tensors and their labels into
        a list of chunks of serialized objects for mongodb.

        :param data: tuple of numpy arrays, aligned with self.sample
        :param index: sample id stamped on every emitted chunk
        :returns: list of chunk documents ready for insert_many
        """

        def chunk_binobj(tensor_compressed, idx, kind, chunksize):
            """
            Yield chunk documents of at most `chunksize` megabytes each.
            """
            # Convert chunksize from megabytes to bytes
            chunksize_bytes = chunksize * 1024 * 1024
            # Calculate the number of chunks
            num_chunks = len(tensor_compressed) // chunksize_bytes
            if len(tensor_compressed) % chunksize_bytes != 0:
                num_chunks += 1
            # Yield chunks
            for i in range(num_chunks):
                start = i * chunksize_bytes
                end = min((i + 1) * chunksize_bytes, len(tensor_compressed))
                chunk = tensor_compressed[start:end]
                yield {
                    "id": idx,
                    "chunk_id": i,
                    "kind": kind,
                    "chunk": bson.Binary(chunk),
                }

        def tensor2bin(tensor):
            """
            Serialize a torch tensor into bytes via torch.save.
            """
            # NOTE(review): casting to uint8 is lossy for float tensors or
            # values outside 0..255 — confirm inputs are byte-valued.
            tensor_1d = tensor.to(torch.uint8)
            buffer = io.BytesIO()
            torch.save(tensor_1d, buffer)
            tensor_binary = buffer.getvalue()
            return tensor_binary

        chunks = []
        binobj = data
        kinds = self.sample
        # One chunk series per kind; all share the same sample index.
        for i, kind in enumerate(kinds):
            chunks += list(
                chunk_binobj(tensor2bin(torch.from_numpy(binobj[i])), index, kind,
                             self.chunksize))
        return chunks

    def push_chunks(self, chunks):
        """ Pushes chunkified tensors to mongodb, with error handling """
        collection_bin = self.db[self.collectionw]
        try:
            collection_bin.insert_many(chunks)
        except Exception as exception:
            # Best-effort: inserts may fail while the manager is swapping
            # collections; log and back off instead of crashing the loop.
            print(f"Generator: An error occurred: {exception}, are you swapping?")
            time.sleep(1)

    def get_current_idx(self):
        """ Atomically increment the shared sample counter and return its
        previous value (ReturnDocument.BEFORE), giving each push a unique
        write index across generator processes """
        dbc = self.db[self.collectionc]
        counter_doc = dbc.find_one_and_update(
            {"_id": "uniqueFieldCounter"},
            {"$inc": {
                "sequence_value": 1
            }},
            return_document=ReturnDocument.BEFORE,
        )
        return counter_doc["sequence_value"]

    def generate_and_insert(self):
        """ Fetch from generator and inserts into mongodb """
        # 0. Fetch data from generator
        data = next(self.generator)
        # 1. Get the correct index for this current sample
        index = self.get_current_idx()
        # 2. Turn the data into a list of serialized chunks
        chunks = self.chunkify(data, index)
        # 3. Push to mongodb + error handling
        # Samples whose index exceeds the cap are silently dropped; the
        # manager resets the counter at the next swap.
        if index < self.swap_cap:
            self.push_chunks(chunks)

    def run_generator(self):
        """ Runs the generate-and-insert loop n_samples times, serializing
        and pushing each yielded sample to mongoDB """
        print("Generator: Initialized")
        n_samples = self.n_samples
        for _ in range(n_samples):
            self.generate_and_insert()
@@ -0,0 +1,117 @@
1
+ """ Wirehead Manager Class """
2
+
3
+ import os
4
+ import time
5
+ import yaml
6
+ from pymongo import MongoClient, ASCENDING
7
+
8
+
9
class WireheadManager():
    """
    Manages the state of the mongo collections in Wirehead.

    Watches the write collection's shared counter and, once it reaches
    SWAP_CAP, rotates write -> temp -> read and resets the counter.

    :param config_path: path to yaml file containing wirehead configs
    """

    def __init__(self, config_path):
        # NOTE(review): on a bad config this prints and returns, leaving a
        # half-initialized object; run_manager() would then fail.
        if config_path is None or os.path.exists(config_path) is False:
            print("No valid config specified, exiting")
            return
        self.load_from_yaml(config_path)

    def load_from_yaml(self, config_path):
        """
        Loads manager configs from config_path.
        """
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)

        dbname = config.get('DBNAME')
        mongohost = config.get('MONGOHOST')
        port = config.get('PORT') if config.get('PORT') is not None else 27017
        client = MongoClient("mongodb://" + mongohost + ":" + str(port))

        self.db = client[dbname]
        self.swap_cap = config.get('SWAP_CAP')
        # ".bin" addresses the chunk sub-collection of each collection
        self.collectionw = config.get("WRITE_COLLECTION") + ".bin"
        self.collectionr = config.get("READ_COLLECTION") + ".bin"
        self.collectionc = config.get("COUNTER_COLLECTION")
        self.collectiont = config.get("TEMP_COLLECTION") + ".bin"

    def run_manager(self):
        """
        Initializes the database manager, swaps and cleans the database
        whenever swap_cap is hit. Runs forever, busy-polling the counter
        via watch_and_swap.
        """
        print("Manager: Initialized")
        self.db["status"].insert_one({"swapped": False})
        self.reset_counter_and_collection()
        generated = 0
        while True:
            generated = self.watch_and_swap(generated)

    def verify_collection_integrity(self, collection):
        """
        Verifies collection contains contiguous elements with id
        0..swap_cap - 1. Uses assert, so these checks are skipped when
        Python runs with -O.
        """
        unique_ids_count = len(collection.distinct("id"))
        assert (
            unique_ids_count == self.swap_cap
        ), f"Manager: Expected {self.swap_cap} unique ids, found {unique_ids_count}"
        expected_ids_set = set(range(self.swap_cap))
        actual_ids_set = set(collection.distinct("id"))
        assert (expected_ids_set == actual_ids_set
                ), "Manager: The ids aren't continuous from 0 to self.swap_cap - 1"

    def reset_counter_and_collection(self):
        """
        Wipe the write collection and reset the shared counter to zero.

        NOTE(review): these are separate mongo operations, not a single
        transaction — a generator process may insert between them, which
        is why the write collection is wiped a second time afterwards.
        """
        dbw = self.db[self.collectionw]
        dbc = self.db[self.collectionc]
        dbw.delete_many({})  # wipe the write collection
        # Reset the counter to zero
        _result = dbc.update_one(
            {"_id": "uniqueFieldCounter"},  # Query part: the document to match
            {"$set": {
                "sequence_value": 0
            }},  # Update part: what to set if the document is matched/found
            upsert=True,
        )
        dbw.delete_many({})  # second wipe catches inserts racing the reset
        dbw.create_index([("id", ASCENDING)], background=True)

    def swap(self, generated):
        """
        Moves data from write collection to read collection
        Deletes old write collection
        Maintains data integrity in between
        """
        time.sleep(2)  # Buffer for incomplete ops
        generated += self.swap_cap
        print("\n----swap----")
        print(f"Manager: Generated samples so far {generated}")
        self.db[self.collectionw].rename(self.collectiont, dropTarget=True)
        # Now reset the counter to 0 and delete whatever records may have
        # been written between the execution of the previous line and the
        # next (see reset_counter_and_collection regarding atomicity)
        self.reset_counter_and_collection()
        # Drop overflow samples (ids >= swap_cap) from the swapped-out data
        result = self.db[self.collectiont].delete_many(
            {"id": {
                "$gt": self.swap_cap - 1
            }})
        # Print the result of the deletion
        print(f"Manager: Documents deleted: {result.deleted_count}")
        self.verify_collection_integrity(self.db[self.collectiont])
        self.db[self.collectiont].rename(self.collectionr, dropTarget=True)
        # Signal readers (MongoheadDataset.wait_for_data) that data is ready
        self.db["status"].insert_one({"swapped": True})
        return generated

    def watch_and_swap(self, generated):
        """
        Watch the write collection and swap when full.

        :param generated: running total of samples generated so far
        :returns: updated running total
        """
        counter_doc = self.db[self.collectionc].find_one({"_id": "uniqueFieldCounter"})
        if counter_doc["sequence_value"] >= self.swap_cap:  # watch
            return self.swap(generated)  # swap
        return generated
@@ -0,0 +1,4 @@
1
+ Metadata-Version: 2.1
2
+ Name: wirehead
3
+ Version: 0.1.0
4
+ License-File: LICENSE
@@ -0,0 +1,12 @@
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ wirehead/__init__.py
5
+ wirehead/dataset.py
6
+ wirehead/generator.py
7
+ wirehead/manager.py
8
+ wirehead.egg-info/PKG-INFO
9
+ wirehead.egg-info/SOURCES.txt
10
+ wirehead.egg-info/dependency_links.txt
11
+ wirehead.egg-info/requires.txt
12
+ wirehead.egg-info/top_level.txt
@@ -0,0 +1,4 @@
1
+ pymongo
2
+ torch
3
+ numpy
4
+ PyYaml
@@ -0,0 +1 @@
1
+ wirehead