redisbench-admin 0.11.63__py3-none-any.whl → 0.11.65__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry and reflects the changes between those versions as they appear there. It is provided for informational purposes only.
- redisbench_admin/run/ann/pkg/.dockerignore +2 -0
- redisbench_admin/run/ann/pkg/.git +1 -0
- redisbench_admin/run/ann/pkg/.github/workflows/benchmarks.yml +100 -0
- redisbench_admin/run/ann/pkg/.gitignore +21 -0
- redisbench_admin/run/ann/pkg/LICENSE +21 -0
- redisbench_admin/run/ann/pkg/README.md +157 -0
- redisbench_admin/run/ann/pkg/algos.yaml +1294 -0
- redisbench_admin/run/ann/pkg/algosP.yaml +67 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/__init__.py +2 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/__init__.py +0 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/annoy.py +26 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/balltree.py +22 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/base.py +36 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/bruteforce.py +110 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/ckdtree.py +17 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/datasketch.py +29 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/definitions.py +187 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/diskann.py +190 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/dolphinnpy.py +31 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/dummy_algo.py +25 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/elasticsearch.py +107 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/elastiknn.py +124 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/faiss.py +124 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/faiss_gpu.py +61 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/faiss_hnsw.py +39 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/flann.py +27 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/hnswlib.py +36 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/kdtree.py +22 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/kgraph.py +39 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/lshf.py +25 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/milvus.py +99 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/mrpt.py +41 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/n2.py +28 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/nearpy.py +48 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/nmslib.py +74 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/onng_ngt.py +100 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/opensearchknn.py +107 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/panng_ngt.py +79 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/pinecone.py +39 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/puffinn.py +45 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/pynndescent.py +115 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/qg_ngt.py +102 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/redisearch.py +90 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/rpforest.py +20 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/scann.py +34 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/sptag.py +28 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/subprocess.py +246 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/vald.py +149 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/vecsim-hnsw.py +43 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/vespa.py +47 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/constants.py +1 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/data.py +48 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/datasets.py +620 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/distance.py +53 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/main.py +325 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/__init__.py +2 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/metrics.py +183 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/plot_variants.py +17 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/utils.py +165 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/results.py +71 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/runner.py +333 -0
- redisbench_admin/run/ann/pkg/create_dataset.py +12 -0
- redisbench_admin/run/ann/pkg/create_hybrid_dataset.py +147 -0
- redisbench_admin/run/ann/pkg/create_text_to_image_ds.py +117 -0
- redisbench_admin/run/ann/pkg/create_website.py +272 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile +11 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.annoy +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.datasketch +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.diskann +29 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.diskann_pq +31 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.dolphinn +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.elasticsearch +45 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.elastiknn +61 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.faiss +18 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.flann +10 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.hnswlib +10 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.kgraph +6 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.mih +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.milvus +27 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.mrpt +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.n2 +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.nearpy +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.ngt +13 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.nmslib +10 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.opensearchknn +43 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.puffinn +6 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.pynndescent +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.redisearch +18 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.rpforest +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.scann +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.scipy +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.sklearn +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.sptag +30 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.vald +8 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.vespa +17 -0
- redisbench_admin/run/ann/pkg/install.py +70 -0
- redisbench_admin/run/ann/pkg/logging.conf +34 -0
- redisbench_admin/run/ann/pkg/multirun.py +298 -0
- redisbench_admin/run/ann/pkg/plot.py +159 -0
- redisbench_admin/run/ann/pkg/protocol/bf-runner +10 -0
- redisbench_admin/run/ann/pkg/protocol/bf-runner.py +204 -0
- redisbench_admin/run/ann/pkg/protocol/ext-add-query-metric.md +51 -0
- redisbench_admin/run/ann/pkg/protocol/ext-batch-queries.md +77 -0
- redisbench_admin/run/ann/pkg/protocol/ext-prepared-queries.md +77 -0
- redisbench_admin/run/ann/pkg/protocol/ext-query-parameters.md +47 -0
- redisbench_admin/run/ann/pkg/protocol/specification.md +194 -0
- redisbench_admin/run/ann/pkg/requirements.txt +14 -0
- redisbench_admin/run/ann/pkg/requirements_py38.txt +11 -0
- redisbench_admin/run/ann/pkg/results/fashion-mnist-784-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/results/gist-960-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/results/glove-100-angular.png +0 -0
- redisbench_admin/run/ann/pkg/results/glove-25-angular.png +0 -0
- redisbench_admin/run/ann/pkg/results/lastfm-64-dot.png +0 -0
- redisbench_admin/run/ann/pkg/results/mnist-784-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/results/nytimes-256-angular.png +0 -0
- redisbench_admin/run/ann/pkg/results/sift-128-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/run.py +12 -0
- redisbench_admin/run/ann/pkg/run_algorithm.py +3 -0
- redisbench_admin/run/ann/pkg/templates/chartjs.template +102 -0
- redisbench_admin/run/ann/pkg/templates/detail_page.html +23 -0
- redisbench_admin/run/ann/pkg/templates/general.html +58 -0
- redisbench_admin/run/ann/pkg/templates/latex.template +30 -0
- redisbench_admin/run/ann/pkg/templates/summary.html +60 -0
- redisbench_admin/run/ann/pkg/test/__init__.py +0 -0
- redisbench_admin/run/ann/pkg/test/test-jaccard.py +19 -0
- redisbench_admin/run/ann/pkg/test/test-metrics.py +99 -0
- redisbench_admin/run_async/run_async.py +2 -2
- redisbench_admin/run_local/run_local.py +2 -2
- redisbench_admin/run_remote/run_remote.py +9 -5
- {redisbench_admin-0.11.63.dist-info → redisbench_admin-0.11.65.dist-info}/METADATA +2 -5
- redisbench_admin-0.11.65.dist-info/RECORD +243 -0
- {redisbench_admin-0.11.63.dist-info → redisbench_admin-0.11.65.dist-info}/WHEEL +1 -1
- redisbench_admin-0.11.63.dist-info/RECORD +0 -117
- {redisbench_admin-0.11.63.dist-info/licenses → redisbench_admin-0.11.65.dist-info}/LICENSE +0 -0
- {redisbench_admin-0.11.63.dist-info → redisbench_admin-0.11.65.dist-info}/entry_points.txt +0 -0

redisbench_admin/run/ann/pkg/ann_benchmarks/results.py (new file, @@ -0,0 +1,71 @@)

```python
from __future__ import absolute_import

import h5py
import json
import os
import re
import traceback


def get_result_filename(dataset=None, count=None, definition=None,
                        query_arguments=None, batch_mode=False, id=0):
    d = ['results']
    if dataset:
        d.append(dataset)
    if count:
        d.append(str(count))
    if definition:
        d.append(definition.algorithm + ('-batch' if batch_mode else ''))
        data = definition.arguments + query_arguments
        for i in range(len(data)):
            if isinstance(data[i], dict):
                data[i] = {k:data[i][k] for k in data[i] if data[i][k] is not None and k != 'auth'}
        data.append('client')
        data.append(id)
        d.append(re.sub(r'\W+', '_', json.dumps(data, sort_keys=True)).strip('_') + ".hdf5")
    return os.path.join(*d)


def store_results(dataset, count, definition, query_arguments, attrs, results,
                  batch, id=0):
    fn = get_result_filename(
        dataset, count, definition, query_arguments, batch, id)
    head, tail = os.path.split(fn)
    if not os.path.isdir(head):
        os.makedirs(head)
    f = h5py.File(fn, 'w')
    for k, v in attrs.items():
        f.attrs[k] = v
    times = f.create_dataset('times', (len(results),), 'f')
    neighbors = f.create_dataset('neighbors', (len(results), count), 'i')
    distances = f.create_dataset('distances', (len(results), count), 'f')
    for i, (time, ds) in enumerate(results):
        times[i] = time
        neighbors[i] = [n for n, d in ds] + [-1] * (count - len(ds))
        distances[i] = [d for n, d in ds] + [float('inf')] * (count - len(ds))
    f.close()


def load_all_results(dataset=None, count=None, batch_mode=False):
    for root, _, files in os.walk(get_result_filename(dataset, count)):
        for fn in files:
            if os.path.splitext(fn)[-1] != '.hdf5':
                continue
            try:
                f = h5py.File(os.path.join(root, fn), 'r+')
                properties = dict(f.attrs)
                if batch_mode != properties['batch_mode']:
                    continue
                yield properties, f
                f.close()
            except:
                print('Was unable to read', fn)
                traceback.print_exc()


def get_unique_algorithms():
    algorithms = set()
    for batch_mode in [False, True]:
        for properties, _ in load_all_results(batch_mode=batch_mode):
            algorithms.add(properties['algo'])
    return algorithms
```
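As a usage sketch (not part of the diff), the path layout produced by `get_result_filename` can be seen with a stand-in definition object; the `SimpleNamespace` stub below is hypothetical and only mimics the `algorithm`/`arguments` attributes of the real `Definition` class:

```python
# Hypothetical stand-in for ann_benchmarks.algorithms.definitions.Definition,
# carrying only the attributes get_result_filename reads.
from types import SimpleNamespace

from ann_benchmarks.results import get_result_filename

definition = SimpleNamespace(algorithm="hnswlib",
                             arguments=[{"M": 16, "efConstruction": 200}])

# results/<dataset>/<count>/<algorithm>/<sanitized args>_client_<id>.hdf5
fn = get_result_filename("glove-100-angular", 10, definition,
                         query_arguments=[], batch_mode=False, id=2)
print(fn)  # results/glove-100-angular/10/hnswlib/M_16_efConstruction_200_client_2.hdf5
```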
redisbench_admin/run/ann/pkg/ann_benchmarks/runner.py (new file, @@ -0,0 +1,333 @@)

```python
import argparse
import json
import logging
import os
import threading
import time
import traceback
import inspect
import h5py

import colors
import docker
import numpy
import psutil

from ann_benchmarks.algorithms.definitions import (Definition,
                                                   instantiate_algorithm)
from ann_benchmarks.datasets import get_dataset, DATASETS
from ann_benchmarks.distance import metrics, dataset_transform
from ann_benchmarks.results import get_result_filename, store_results


def run_individual_query(algo, X_train, X_test, distance, count, run_count,
                         batch):
    prepared_queries = \
        (batch and hasattr(algo, "prepare_batch_query")) or \
        ((not batch) and hasattr(algo, "prepare_query"))

    best_search_time = float('inf')
    for i in range(run_count):
        print('Run %d/%d...' % (i + 1, run_count))
        # a bit dumb but can't be a scalar since of Python's scoping rules
        n_items_processed = [0]

        def single_query(v):
            if prepared_queries:
                algo.prepare_query(v, count)
                start = time.time()
                algo.run_prepared_query()
                total = (time.time() - start)
                candidates = algo.get_prepared_query_results()
            else:
                start = time.time()
                candidates = algo.query(v, count)
                total = (time.time() - start)
            candidates = [(int(idx), float(metrics[distance]['distance'](v, X_train[idx]))) # noqa
                          for idx in candidates]
            n_items_processed[0] += 1
            if n_items_processed[0] % 1000 == 0:
                print('Processed %d/%d queries...' % (n_items_processed[0], len(X_test)))
            if len(candidates) > count:
                print('warning: algorithm %s returned %d results, but count'
                      ' is only %d)' % (algo, len(candidates), count))
            return (total, candidates)

        def batch_query(X):
            if prepared_queries:
                algo.prepare_batch_query(X, count)
                start = time.time()
                algo.run_batch_query()
                total = (time.time() - start)
            else:
                start = time.time()
                algo.batch_query(X, count)
                total = (time.time() - start)
            results = algo.get_batch_results()
            candidates = [[(int(idx), float(metrics[distance]['distance'](v, X_train[idx]))) # noqa
                           for idx in single_results]
                          for v, single_results in zip(X, results)]
            return [(total / float(len(X)), v) for v in candidates]

        if batch:
            results = batch_query(X_test)
        else:
            results = [single_query(x) for x in X_test]

        total_time = sum(time for time, _ in results)
        total_candidates = sum(len(candidates) for _, candidates in results)
        search_time = total_time / len(X_test)
        avg_candidates = total_candidates / len(X_test)
        best_search_time = min(best_search_time, search_time)
        print("qps:", len(X_test)/total_time)

    verbose = hasattr(algo, "query_verbose")
    attrs = {
        "batch_mode": batch,
        "best_search_time": best_search_time,
        "candidates": avg_candidates,
        "expect_extra": verbose,
        "name": str(algo),
        "run_count": run_count,
        "distance": distance,
        "count": int(count)
    }
    additional = algo.get_additional()
    for k in additional:
        attrs[k] = additional[k]
    return (attrs, results)


def run(definition, dataset, count, run_count, batch, build_only, test_only, num_clients, id):
    algo = instantiate_algorithm(definition)
    assert not definition.query_argument_groups \
        or hasattr(algo, "set_query_arguments"), """\
error: query argument groups have been specified for %s.%s(%s), but the \
algorithm instantiated from it does not implement the set_query_arguments \
function""" % (definition.module, definition.constructor, definition.arguments)

    D, dimension = get_dataset(dataset)
    X_train, X_test = dataset_transform(D)
    distance = D.attrs['distance']
    print('got a train set of size (%d * %d)' % (X_train.shape[0], dimension))

    hybrid_buckets = None
    if 'bucket_names' in D.attrs:
        hybrid_buckets = {}
        bucket_names = D.attrs['bucket_names']
        for bucket_name in bucket_names:
            bucket_dict = {}
            bucket_dict['ids'] = numpy.array(D[f'{bucket_name}_ids'])
            bucket_dict['text'] = D[bucket_name]['text'][()]
            bucket_dict['number'] = D[bucket_name]['number'][()]
            hybrid_buckets[bucket_name] = bucket_dict

    try:
        prepared_queries = False
        if hasattr(algo, "supports_prepared_queries"):
            prepared_queries = algo.supports_prepared_queries()

        if not test_only:
            per_client = len(X_train) // num_clients
            offset = per_client * (id - 1)
            fit_kwargs = {}
            if "offset" and "limit" in inspect.getfullargspec(algo.fit)[0]:
                fit_kwargs['offset'] = offset
                if num_clients != id:
                    fit_kwargs['limit'] = offset + per_client
            if hybrid_buckets:
                fit_kwargs['hybrid_buckets'] = hybrid_buckets

            t0 = time.time()
            memory_usage_before = algo.get_memory_usage()
            algo.fit(X_train, **fit_kwargs)
            build_time = time.time() - t0
            index_size = algo.get_memory_usage() - memory_usage_before
            print('Built index in', build_time)
            print('Index size: ', index_size)

        query_argument_groups = definition.query_argument_groups
        # Make sure that algorithms with no query argument groups still get run
        # once by providing them with a single, empty, harmless group
        if not query_argument_groups:
            query_argument_groups = [[]]

        if not build_only:
            print('got %d queries' % len(X_test))
            per_client = len(X_test) // num_clients
            offset = per_client * (id - 1)
            if (num_clients != id):
                X_test = X_test[offset : offset + per_client]
            else:
                X_test = X_test[offset:]
            print('running %d out of them' % len(X_test))

            for pos, query_arguments in enumerate(query_argument_groups, 1):
                print("Running query argument group %d of %d..." %
                      (pos, len(query_argument_groups)))
                if query_arguments:
                    algo.set_query_arguments(*query_arguments)
                if hybrid_buckets:
                    text = hybrid_buckets[D.attrs['selected_bucket']]['text'].decode()
                    print("setting hybrid text query", text)
                    algo.set_hybrid_query(text)
                descriptor, results = run_individual_query(
                    algo, X_train, X_test, distance, count, run_count, batch)
                if test_only:
                    try:
                        fn = get_result_filename(dataset, count)
                        fn = os.path.join(fn, definition.algorithm, 'build_stats')
                        f = h5py.File(fn, 'r')
                        descriptor["build_time"] = f.attrs["build_time"]
                        descriptor["index_size"] = f.attrs["index_size"]
                        f.close()
                    except:
                        descriptor["build_time"] = 0
                        descriptor["index_size"] = 0
                else:
                    descriptor["build_time"] = build_time
                    descriptor["index_size"] = index_size
                descriptor["algo"] = definition.algorithm
                descriptor["dataset"] = dataset
                store_results(dataset, count, definition, query_arguments,
                              descriptor, results, batch, id)
    finally:
        algo.done()


def run_from_cmdline():
    parser = argparse.ArgumentParser('''

NOTICE: You probably want to run.py rather than this script.

''')
    parser.add_argument(
        '--dataset',
        choices=DATASETS.keys(),
        help=f'Dataset to benchmark on.',
        required=True)
    parser.add_argument(
        '--algorithm',
        help='Name of algorithm for saving the results.',
        required=True)
    parser.add_argument(
        '--module',
        help='Python module containing algorithm. E.g. "ann_benchmarks.algorithms.annoy"',
        required=True)
    parser.add_argument(
        '--constructor',
        help='Constructer to load from modulel. E.g. "Annoy"',
        required=True)
    parser.add_argument(
        '--count',
        help='K: Number of nearest neighbours for the algorithm to return.',
        required=True,
        type=int)
    parser.add_argument(
        '--runs',
        help='Number of times to run the algorihm. Will use the fastest run-time over the bunch.',
        required=True,
        type=int)
    parser.add_argument(
        '--batch',
        help='If flag included, algorithms will be run in batch mode, rather than "individual query" mode.',
        action='store_true')
    parser.add_argument(
        '--build-only',
        action='store_true',
        help='building index only, not testing with queries')
    parser.add_argument(
        '--test-only',
        action='store_true',
        help='querying index only, not building it (should be built first)')
    parser.add_argument(
        'build',
        help='JSON of arguments to pass to the constructor. E.g. ["angular", 100]'
        )
    parser.add_argument(
        'queries',
        help='JSON of arguments to pass to the queries. E.g. [100]',
        nargs='*',
        default=[])
    args = parser.parse_args()
    algo_args = json.loads(args.build)
    print(algo_args)
    query_args = [json.loads(q) for q in args.queries]

    definition = Definition(
        algorithm=args.algorithm,
        docker_tag=None,  # not needed
        module=args.module,
        constructor=args.constructor,
        arguments=algo_args,
        query_argument_groups=query_args,
        disabled=False
    )
    run(definition, args.dataset, args.count, args.runs, args.batch, args.build_only, args.test_only, 1, 1)


def run_docker(definition, dataset, count, runs, timeout, batch, cpu_limit,
               mem_limit=None):
    cmd = ['--dataset', dataset,
           '--algorithm', definition.algorithm,
           '--module', definition.module,
           '--constructor', definition.constructor,
           '--runs', str(runs),
           '--count', str(count)]
    if batch:
        cmd += ['--batch']
    cmd.append(json.dumps(definition.arguments))
    cmd += [json.dumps(qag) for qag in definition.query_argument_groups]

    client = docker.from_env()
    if mem_limit is None:
        mem_limit = psutil.virtual_memory().available

    container = client.containers.run(
        definition.docker_tag,
        cmd,
        volumes={
            os.path.abspath('ann_benchmarks'):
                {'bind': '/home/app/ann_benchmarks', 'mode': 'ro'},
            os.path.abspath('data'):
                {'bind': '/home/app/data', 'mode': 'ro'},
            os.path.abspath('results'):
                {'bind': '/home/app/results', 'mode': 'rw'},
        },
        cpuset_cpus=cpu_limit,
        mem_limit=mem_limit,
        detach=True)
    logger = logging.getLogger(f"annb.{container.short_id}")

    logger.info('Created container %s: CPU limit %s, mem limit %s, timeout %d, command %s' % \
                (container.short_id, cpu_limit, mem_limit, timeout, cmd))

    def stream_logs():
        for line in container.logs(stream=True):
            logger.info(colors.color(line.decode().rstrip(), fg='blue'))

    t = threading.Thread(target=stream_logs, daemon=True)
    t.start()

    try:
        return_value = container.wait(timeout=timeout)
        _handle_container_return_value(return_value, container, logger)
    except:
        logger.error('Container.wait for container %s failed with exception' % container.short_id)
        traceback.print_exc()
    finally:
        container.remove(force=True)

def _handle_container_return_value(return_value, container, logger):
    base_msg = 'Child process for container %s' % (container.short_id)
    if type(return_value) is dict:  # The return value from container.wait changes from int to dict in docker 3.0.0
        error_msg = return_value['Error']
        exit_code = return_value['StatusCode']
        msg = base_msg + 'returned exit code %d with message %s' % (exit_code, error_msg)
    else:
        exit_code = return_value
        msg = base_msg + 'returned exit code %d' % (exit_code)

    if exit_code not in [0, None]:
        logger.error(colors.color(container.logs().decode(), fg='red'))
        logger.error(msg)
```
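Worth noting from `run()` above: when `num_clients > 1`, each client takes an equal slice of the train and test sets, with the last client (the one whose `id` equals `num_clients`) also absorbing the remainder. A standalone sketch of that arithmetic (not part of the package):

```python
# Illustrative only: how run() shards items across num_clients using the
# 1-based client id, mirroring the offset / per_client logic above.
def client_slice(n_items, num_clients, client_id):
    per_client = n_items // num_clients
    offset = per_client * (client_id - 1)
    if client_id != num_clients:
        return offset, offset + per_client   # equal share for non-final clients
    return offset, n_items                   # final client also takes the remainder

# e.g. 10,000 test queries split over 3 clients:
# client 1 -> (0, 3333), client 2 -> (3333, 6666), client 3 -> (6666, 10000)
for cid in (1, 2, 3):
    print(cid, client_slice(10_000, 3, cid))
```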
redisbench_admin/run/ann/pkg/create_dataset.py (new file, @@ -0,0 +1,12 @@)

```python
import argparse
from ann_benchmarks.datasets import DATASETS, get_dataset_fn

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset',
        choices=DATASETS.keys(),
        required=True)
    args = parser.parse_args()
    fn = get_dataset_fn(args.dataset)
    DATASETS[args.dataset](fn)
```
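Functionally, `create_dataset.py` just resolves a dataset name to its HDF5 file and invokes the registered builder. A minimal sketch of the same flow, assuming `ann_benchmarks.datasets` is importable from the pkg directory and that `get_dataset_fn` resolves the name to a local path:

```python
# Minimal sketch (illustrative) of the create_dataset.py flow for one dataset.
from ann_benchmarks.datasets import DATASETS, get_dataset_fn

name = "glove-100-angular"   # any key of DATASETS
fn = get_dataset_fn(name)    # assumed to resolve to a local HDF5 path for this dataset
DATASETS[name](fn)           # downloads/builds the HDF5 dataset at fn
```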
redisbench_admin/run/ann/pkg/create_hybrid_dataset.py (new file, @@ -0,0 +1,147 @@)

```python
from operator import ne
import click
from ann_benchmarks.datasets import get_dataset, DATASETS
from ann_benchmarks.algorithms.bruteforce import BruteForceBLAS
import struct
import numpy as np
import click
import h5py
from joblib import Parallel, delayed
import multiprocessing
import scipy.spatial

def calc_i(i, x, bf, test, neighbors, distances, count, orig_ids):
    if i % 1000 == 0:
        print('%d/%d...' % (i, len(test)))
    res = list(bf.query_with_distances(x, count))
    res.sort(key=lambda t: t[-1])
    neighbors[i] = [orig_ids[j] for j, _ in res]
    distances[i] = [d for _, d in res]

def create_buckets(train):
    bucket_0_5 = []
    bucket_1 = []
    bucket_2 = []
    bucket_5 = []
    bucket_10 = []
    bucket_20 = []
    bucket_50 = []
    other_bucket = []
    buckets = [bucket_0_5, bucket_1, bucket_2, bucket_5, bucket_10, bucket_20, bucket_50, other_bucket]
    bucket_names=['0.5', '1', '2', '5', '10', '20', '50', 'other']
    for i in range(train.shape[0]):
        if i % 200 == 19: # 0.5%
            bucket_0_5.append(i)
        elif i % 100 == 17: # 1%
            bucket_1.append(i)
        elif i % 50 == 9: # 2%
            bucket_2.append(i)
        elif i % 20 == 7: # 5%
            bucket_5.append(i)
        elif i % 10 == 3: # 10%
            bucket_10.append(i)
        elif i % 2 == 0: # 50%
            bucket_50.append(i)
        elif i % 5 <= 1: # 20%
            bucket_20.append(i)
        else:
            other_bucket.append(i)
    print(len(bucket_0_5), len(bucket_1), len(bucket_2), len(bucket_5), len(bucket_10), len(bucket_20), len(bucket_50), len(other_bucket))
    numeric_values = {}
    text_values = {}
    for i, bucket_name in enumerate(bucket_names):
        numeric_values[bucket_name] = i
        text_values[bucket_name] = f'text_{i}'
    print(numeric_values)
    print(text_values)
    return buckets, bucket_names, numeric_values, text_values

@click.command()
@click.option('--data_set', type=click.Choice(DATASETS.keys(), case_sensitive=False), default='glove-100-angular')
@click.option('--percentile', type=click.Choice(['0.5', '1', '2', '5', '10', '20', '50'], case_sensitive=False), default=None)
def create_ds(data_set, percentile):
    ds, dimension= get_dataset(data_set)
    train = ds['train']
    test = ds['test']
    distance = ds.attrs['distance']
    count=len(ds['neighbors'][0])
    print(count)
    print(train.shape)
    buckets, bucket_names, numeric_values, text_values = create_buckets(train)

    if percentile is not None:
        i = ['0.5', '1', '2', '5', '10', '20', '50'].index(percentile)
        bucket = buckets[i]
        fn=f'{data_set}-hybrid-{bucket_names[i]}.hdf5'
        with h5py.File(fn, 'w') as f:
            f.attrs['type'] = 'dense'
            f.attrs['distance'] = ds.attrs['distance']
            f.attrs['dimension'] = len(test[0])
            f.attrs['point_type'] = 'float'
            f.attrs['bucket_names'] = bucket_names
            f.attrs['selected_bucket'] = bucket_names[i]
            for bucket_name in bucket_names:
                grp = f.create_group(bucket_name)
                grp["text"] = text_values[bucket_name]
                grp["number"] = numeric_values[bucket_name]

            f.create_dataset('train', train.shape, dtype=train.dtype)[:] = train
            f.create_dataset('test', test.shape, dtype=test.dtype)[:] = test
            # Write the id buckets so on ingestion we will know what data to assign for each id.

            for j, id_bucket in enumerate(buckets):
                np_bucket = np.array(id_bucket, dtype=np.int32)
                f.create_dataset(f'{bucket_names[j]}_ids', np_bucket.shape, dtype=np_bucket.dtype)[:] = np_bucket

            neighbors = f.create_dataset(f'neighbors', (len(test), count), dtype='i')
            distances = f.create_dataset(f'distances', (len(test), count), dtype='f')

            # Generate ground truth only for the relevan bucket.
            train_bucket = np.array(bucket, dtype = np.int32)
            train_set = np.empty((len(bucket), train.shape[1]), dtype=np.float32)
            for id in range(len(bucket)):
                train_set[id] = train[bucket[id]]
            bf = BruteForceBLAS(distance, precision=train.dtype)
            bf.fit(train_set)
            Parallel(n_jobs=multiprocessing.cpu_count(), require='sharedmem')(delayed(calc_i)(i, x, bf, test, neighbors, distances, count, train_bucket) for i, x in enumerate(test))

    else:
        for i, bucket in enumerate(buckets):
            fn=f'{data_set}-hybrid-{bucket_names[i]}.hdf5'
            with h5py.File(fn, 'w') as f:
                f.attrs['type'] = 'dense'
                f.attrs['distance'] = ds.attrs['distance']
                f.attrs['dimension'] = len(test[0])
                f.attrs['point_type'] = 'float'
                f.attrs['bucket_names'] = bucket_names
                f.attrs['selected_bucket'] = bucket_names[i]
                for bucket_name in bucket_names:
                    grp = f.create_group(bucket_name)
                    grp["text"] = text_values[bucket_name]
                    grp["number"] = numeric_values[bucket_name]

                f.create_dataset('train', train.shape, dtype=train.dtype)[:] = train
                f.create_dataset('test', test.shape, dtype=test.dtype)[:] = test
                # Write the id buckets so on ingestion we will know what data to assign for each id.
                for j, id_bucket in enumerate(buckets):
                    np_bucket = np.array(id_bucket, dtype=np.int32)
                    f.create_dataset(f'{bucket_names[j]}_ids', np_bucket.shape, dtype=np_bucket.dtype)[:] = np_bucket

                neighbors = f.create_dataset(f'neighbors', (len(test), count), dtype='i')
                distances = f.create_dataset(f'distances', (len(test), count), dtype='f')

                # Generate ground truth only for the relevan bucket.
                train_bucket = np.array(bucket, dtype = np.int32)
                train_set = np.empty((len(bucket), train.shape[1]), dtype=np.float32)
                for id in range(len(bucket)):
                    train_set[id] = train[bucket[id]]
                print(train_set.shape)
                bf = BruteForceBLAS(distance, precision=train.dtype)
                bf.fit(train_set)
                Parallel(n_jobs=multiprocessing.cpu_count(), require='sharedmem')(delayed(calc_i)(i, x, bf, test, neighbors, distances, count, train_bucket) for i, x in enumerate(test))
                print(neighbors[1])
                print(distances[1])


if __name__ == "__main__":
    create_ds()
```
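The `create_buckets` rules above assign ids to percentage buckets through modular arithmetic evaluated in order. A quick standalone check (illustrative only, not part of the package) confirms the advertised shares, with everything left over (about 11.5%) falling into the `other` bucket:

```python
# Illustrative check of the modular bucket rules used in create_buckets().
n = 1_000_000
sizes = dict.fromkeys(['0.5', '1', '2', '5', '10', '20', '50', 'other'], 0)
for i in range(n):
    if i % 200 == 19:
        sizes['0.5'] += 1
    elif i % 100 == 17:
        sizes['1'] += 1
    elif i % 50 == 9:
        sizes['2'] += 1
    elif i % 20 == 7:
        sizes['5'] += 1
    elif i % 10 == 3:
        sizes['10'] += 1
    elif i % 2 == 0:
        sizes['50'] += 1
    elif i % 5 <= 1:
        sizes['20'] += 1
    else:
        sizes['other'] += 1

for name, c in sizes.items():
    # prints roughly 0.5, 1, 2, 5, 10, 20, 50 percent, plus ~11.5% for 'other'
    print(f"bucket {name}: {100 * c / n:.1f}%")
```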
redisbench_admin/run/ann/pkg/create_text_to_image_ds.py (new file, @@ -0,0 +1,117 @@)

```python
from ann_benchmarks.algorithms.bruteforce import BruteForceBLAS
import struct
import numpy as np
import click
import h5py
from joblib import Parallel, delayed
import multiprocessing

def read_fbin(filename, start_idx=0, chunk_size=None):
    """ Read *.fbin file that contains float32 vectors
    Args:
        :param filename (str): path to *.fbin file
        :param start_idx (int): start reading vectors from this index
        :param chunk_size (int): number of vectors to read.
                                 If None, read all vectors
    Returns:
        Array of float32 vectors (numpy.ndarray)
    """
    with open(filename, "rb") as f:
        nvecs, dim = np.fromfile(f, count=2, dtype=np.int32)
        nvecs = (nvecs - start_idx) if chunk_size is None else chunk_size
        arr = np.fromfile(f, count=nvecs * dim, dtype=np.float32,
                          offset=start_idx * 4 * dim)
    return arr.reshape(nvecs, dim)


def read_ibin(filename, start_idx=0, chunk_size=None):
    """ Read *.ibin file that contains int32 vectors
    Args:
        :param filename (str): path to *.ibin file
        :param start_idx (int): start reading vectors from this index
        :param chunk_size (int): number of vectors to read.
                                 If None, read all vectors
    Returns:
        Array of int32 vectors (numpy.ndarray)
    """
    with open(filename, "rb") as f:
        nvecs, dim = np.fromfile(f, count=2, dtype=np.int32)
        nvecs = (nvecs - start_idx) if chunk_size is None else chunk_size
        arr = np.fromfile(f, count=nvecs * dim, dtype=np.int32,
                          offset=start_idx * 4 * dim)
    return arr.reshape(nvecs, dim)


def write_fbin(filename, vecs):
    """ Write an array of float32 vectors to *.fbin file
    Args:s
        :param filename (str): path to *.fbin file
        :param vecs (numpy.ndarray): array of float32 vectors to write
    """
    assert len(vecs.shape) == 2, "Input array must have 2 dimensions"
    with open(filename, "wb") as f:
        nvecs, dim = vecs.shape
        f.write(struct.pack('<i', nvecs))
        f.write(struct.pack('<i', dim))
        vecs.astype('float32').flatten().tofile(f)


def write_ibin(filename, vecs):
    """ Write an array of int32 vectors to *.ibin file
    Args:
        :param filename (str): path to *.ibin file
        :param vecs (numpy.ndarray): array of int32 vectors to write
    """
    assert len(vecs.shape) == 2, "Input array must have 2 dimensions"
    with open(filename, "wb") as f:
        nvecs, dim = vecs.shape
        f.write(struct.pack('<i', nvecs))
        f.write(struct.pack('<i', dim))
        vecs.astype('int32').flatten().tofile(f)

def calc_i(i, x, bf, test, neighbors, distances, count):
    if i % 1000 == 0:
        print('%d/%d...' % (i, len(test)))
    res = list(bf.query_with_distances(x, count))
    res.sort(key=lambda t: t[-1])
    neighbors[i] = [j for j, _ in res]
    distances[i] = [d for _, d in res]


def calc(bf, test, neighbors, distances, count):
    Parallel(n_jobs=multiprocessing.cpu_count(), require='sharedmem')(delayed(calc_i)(i, x, bf, test, neighbors, distances, count) for i, x in enumerate(test))


def write_output(train, test, fn, distance, point_type='float', count=100):
    n = 0
    f = h5py.File(fn, 'w')
    f.attrs['type'] = 'dense'
    f.attrs['distance'] = distance
    f.attrs['dimension'] = len(train[0])
    f.attrs['point_type'] = point_type
    print('train size: %9d * %4d' % train.shape)
    print('test size: %9d * %4d' % test.shape)
    f.create_dataset('train', (len(train), len(
        train[0])), dtype=train.dtype)[:] = train
    f.create_dataset('test', (len(test), len(
        test[0])), dtype=test.dtype)[:] = test
    neighbors = f.create_dataset('neighbors', (len(test), count), dtype='i')
    distances = f.create_dataset('distances', (len(test), count), dtype='f')
    bf = BruteForceBLAS(distance, precision=train.dtype)

    bf.fit(train)
    calc(bf, test, neighbors, distances, count)
    f.close()

@click.command()
@click.option('--size', default=10, help='Number of vectors in milions.')
@click.option('--distance', default='angular', help='distance metric.')
@click.option('--test_set', required=True, type=str)
@click.option('--train_set', required=True, type=str)
def create_ds(size, distance, test_set, train_set):
    test_set = read_fbin(test_set)
    train_set= read_fbin(train_set, chunk_size=size*1000000)
    write_output(train=train_set, test=test_set, fn=f'Text-to-Image-{size}M.hdf5', distance=distance, point_type='float', count=100)

if __name__ == "__main__":
    create_ds()
```
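For reference, the `.fbin` layout assumed by these helpers is a little-endian int32 header `(nvecs, dim)` followed by row-major float32 data. A small round-trip sketch, assuming `write_fbin`/`read_fbin` from the file above are in scope and using an arbitrary temporary path:

```python
# Illustrative round trip for the *.fbin layout (not part of the package).
import numpy as np

vecs = np.random.rand(5, 4).astype(np.float32)

write_fbin("/tmp/example.fbin", vecs)      # writes header (5, 4), then 20 float32 values
restored = read_fbin("/tmp/example.fbin")  # reads the header, then reshapes the body to (5, 4)

assert restored.shape == (5, 4)
assert np.allclose(vecs, restored)
```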