nomic 3.0.32.tar.gz → 3.0.34.tar.gz
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of nomic has been flagged as potentially problematic.
- {nomic-3.0.32 → nomic-3.0.34}/PKG-INFO +1 -1
- {nomic-3.0.32 → nomic-3.0.34}/nomic/data_inference.py +1 -1
- {nomic-3.0.32 → nomic-3.0.34}/nomic/data_operations.py +240 -140
- {nomic-3.0.32 → nomic-3.0.34}/nomic/dataset.py +59 -117
- {nomic-3.0.32 → nomic-3.0.34}/nomic/utils.py +33 -15
- {nomic-3.0.32 → nomic-3.0.34}/nomic.egg-info/PKG-INFO +1 -1
- {nomic-3.0.32 → nomic-3.0.34}/setup.py +1 -1
- {nomic-3.0.32 → nomic-3.0.34}/README.md +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/__init__.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/atlas.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/aws/__init__.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/aws/sagemaker.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/cli.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/embed.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/pl_callbacks/__init__.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/pl_callbacks/pl_callback.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic/settings.py +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic.egg-info/SOURCES.txt +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic.egg-info/dependency_links.txt +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic.egg-info/entry_points.txt +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic.egg-info/requires.txt +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/nomic.egg-info/top_level.txt +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/pyproject.toml +0 -0
- {nomic-3.0.32 → nomic-3.0.34}/setup.cfg +0 -0

{nomic-3.0.32 → nomic-3.0.34}/nomic/data_inference.py

@@ -88,7 +88,7 @@ class NomicTopicOptions(BaseModel):
     """

     build_topic_model: bool = True
-    topic_label_field: Optional[str] = Field(default=None
+    topic_label_field: Optional[str] = Field(default=None)
     cluster_method: str = "fast"
     enforce_topic_hierarchy: bool = False

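Note: the removed line above is truncated by the diff viewer, so only the new line can be read in full; it is a standard pydantic optional field. A minimal sketch (a reduced stand-in, not the full NomicTopicOptions model) confirming the fixed line parses and defaults to None:

    from typing import Optional

    from pydantic import BaseModel, Field

    class TopicOptionsSketch(BaseModel):  # reduced stand-in for NomicTopicOptions
        topic_label_field: Optional[str] = Field(default=None)

    print(TopicOptionsSketch().topic_label_field)  # None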

{nomic-3.0.32 → nomic-3.0.34}/nomic/data_operations.py

@@ -1,11 +1,10 @@
 import base64
 import io
 import json
-import os
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -16,9 +15,6 @@ from pyarrow import compute as pc
 from pyarrow import feather
 from tqdm import tqdm

-from .settings import EMBEDDING_PAGINATION_LIMIT
-from .utils import download_feather
-

 class AtlasMapDuplicates:
     """
@@ -30,22 +26,57 @@ class AtlasMapDuplicates:
     def __init__(self, projection: "AtlasProjection"):  # type: ignore
         self.projection = projection
         self.id_field = self.projection.dataset.id_field
-
-
-
-
-
-
-
-        self.
-
-
+
+        duplicate_columns = [
+            (field, sidecar)
+            for field, sidecar in self.projection._registered_columns
+            if field.startswith("_duplicate_class")
+        ]
+        cluster_columns = [
+            (field, sidecar) for field, sidecar in self.projection._registered_columns if field.startswith("_cluster")
+        ]
+
+        assert len(duplicate_columns) > 0, "Duplicate detection has not yet been run on this map."
+
+        self._duplicate_column = duplicate_columns[0]
+        self._cluster_column = cluster_columns[0]
+        self._tb = None
+
+    def _load_duplicates(self):
+        """
+        Loads duplicates from the feather tree.
+        """
+        tbs = []
+        duplicate_sidecar = self._duplicate_column[1]
+        self.duplicate_field = self._duplicate_column[0].lstrip("_")
+        self.cluster_field = self._cluster_column[0].lstrip("_")
+        logger.info("Loading duplicates")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
             )
-
-
-
-
-
+            path = self.projection.tile_destination
+
+            if duplicate_sidecar == "":
+                path = path / Path(key).with_suffix(".feather")
+            else:
+                path = path / Path(key).with_suffix(f".{duplicate_sidecar}.feather")
+
+            duplicate_tb = feather.read_table(path, memory_map=True)
+            for field in (self._duplicate_column[0], self._cluster_column[0]):
+                tb = tb.append_column(field, duplicate_tb[field])
+            tbs.append(tb)
+        self._tb = pa.concat_tables(tbs).rename_columns([self.id_field, self.duplicate_field, self.cluster_field])
+
+    def _download_duplicates(self):
+        """
+        Downloads the feather tree for duplicates.
+        """
+        logger.info("Downloading duplicates")
+        self.projection._download_sidecar("datum_id", overwrite=False)
+        assert self._cluster_column[1] == self._duplicate_column[1], "Cluster and duplicate should be in same sidecar"
+        self.projection._download_sidecar(self._duplicate_column[1], overwrite=False)

     @property
     def df(self) -> pd.DataFrame:
@@ -61,6 +92,10 @@ class AtlasMapDuplicates:
         This table is memmapped from the underlying files and is the most efficient way to
         access duplicate information.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+        self._download_duplicates()
+        self._load_duplicates()
         return self._tb

     def deletion_candidates(self) -> List[str]:
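Note: with this hunk, the `tb`/`df` accessors stop eagerly fetching tiles in `__init__` and instead materialize on first access: download the sidecar tree, assemble the Arrow table, memoize it. A minimal sketch of that accessor pattern, with hypothetical names and toy data standing in for the real download/load steps:

    import pyarrow as pa

    class LazyTable:  # hypothetical stand-in for the map accessor classes
        def __init__(self):
            self._tb = None  # nothing fetched at construction time

        def _download(self):
            # stand-in for _download_duplicates(): fetch feather tiles to disk
            pass

        def _load(self):
            # stand-in for _load_duplicates(): assemble the table from local tiles
            self._tb = pa.table({"id": ["a", "b"], "duplicate_class": ["x", "y"]})

        @property
        def tb(self) -> pa.Table:
            if isinstance(self._tb, pa.Table):  # already materialized
                return self._tb
            self._download()
            self._load()
            return self._tb

    print(LazyTable().tb.num_rows)  # 2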
@@ -93,35 +128,69 @@ class AtlasMapTopics:
         self.id_field = self.projection.dataset.id_field
         self._metadata = None
         self._hierarchy = None
+        self._topic_columns = [
+            column for column in self.projection._registered_columns if column[0].startswith("_topic_depth_")
+        ]
+        assert len(self._topic_columns) > 0, "Topic modeling has not yet been run on this map."
+        self.depth = len(self._topic_columns)
+        self._tb = None

-
-
-
-
-
-
-
-
-
-
-
+    def _load_topics(self):
+        """
+        Loads topics from the feather tree.
+        """
+        integer_topics = False
+        # pd.Series to match pd typing
+        label_df: Optional[Union[pd.DataFrame, pd.Series]] = None
+        if "int" in self._topic_columns[0][0]:
+            integer_topics = True
+            label_df = self.metadata[["topic_id", "depth", "topic_short_description"]]
+        tbs = []
+        # Should just be one sidecar
+        topic_sidecar = set([sidecar for _, sidecar in self._topic_columns]).pop()
+        logger.info("Loading topics")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+            )
+            path = self.projection.tile_destination
+            if topic_sidecar == "":
+                path = path / Path(key).with_suffix(".feather")
+            else:
+                path = path / Path(key).with_suffix(f".{topic_sidecar}.feather")
+
+            topic_tb = feather.read_table(path, memory_map=True)
+            # Do this in depth order
+            for d in range(1, self.depth + 1):
+                column = f"_topic_depth_{d}"
+                if integer_topics:
                     column = f"_topic_depth_{d}_int"
-                topic_ids_to_label =
+                    topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
+                    assert label_df is not None
                     topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
                         topic_ids_to_label, on="topic_id", how="right"
                     )
                     new_column = f"_topic_depth_{d}"
-
+                    tb = tb.append_column(
                         new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
                     )
-
-
+                else:
+                    tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
+            tbs.append(tb)

-
-
+        renamed_columns = [self.id_field] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
+        self._tb = pa.concat_tables(tbs).rename_columns(renamed_columns)

-
-
+    def _download_topics(self):
+        """
+        Downloads the feather tree for topics.
+        """
+        logger.info("Downloading topics")
+        self.projection._download_sidecar("datum_id", overwrite=False)
+        topic_sidecars = set([sidecar for _, sidecar in self._topic_columns])
+        assert len(topic_sidecars) == 1, "Multiple topic sidecars found."
+        self.projection._download_sidecar(topic_sidecars.pop(), overwrite=False)

     @property
     def df(self) -> pd.DataFrame:
@@ -137,6 +206,10 @@ class AtlasMapTopics:
         This table is memmapped from the underlying files and is the most efficient way to
         access topic information.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+        self._download_topics()
+        self._load_topics()
         return self._tb

     @property
@@ -260,8 +333,8 @@ class AtlasMapTopics:
         A list of `{topic, count}` dictionaries, sorted from largest count to smallest count.
         """
         data = AtlasMapData(self.projection, fields=[time_field])
-        time_data = data.
-        merged_tb = self.
+        time_data = data.tb.select([self.id_field, time_field])
+        merged_tb = self.tb.join(time_data, self.id_field, join_type="inner").combine_chunks()

         del time_data  # free up memory

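Note: the rewritten lines join two Arrow tables directly instead of going through pandas. A toy example of the same call shape — an inner join on the id column followed by `combine_chunks()` to defragment the result (the column values here are made up):

    import pyarrow as pa

    topics = pa.table({"id": [1, 2, 3], "topic_depth_1": ["a", "b", "a"]})
    times = pa.table({"id": [1, 2, 3], "created_at": ["2024-01-01", "2024-01-02", "2024-01-03"]})
    merged = topics.join(times, "id", join_type="inner").combine_chunks()
    print(sorted(merged.column_names))  # ['created_at', 'id', 'topic_depth_1']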
@@ -376,8 +449,8 @@ class AtlasMapEmbeddings:
     def __init__(self, projection: "AtlasProjection"):  # type: ignore
         self.projection = projection
         self.id_field = self.projection.dataset.id_field
-        self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, "x", "y"])
         self.dataset = projection.dataset
+        self._tb: pa.Table = None
         self._latent = None

     @property
@@ -398,6 +471,31 @@ class AtlasMapEmbeddings:

         Does not include high-dimensional embeddings.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+
+        self._download_projected()
+
+        logger.info("Loading projected embeddings")
+
+        tbs = []
+        coord_sidecar = self.projection._get_sidecar_from_field("x")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+            )
+            path = self.projection.tile_destination
+            if coord_sidecar == "":
+                path = path / Path(key).with_suffix(".feather")
+            else:
+                path = path / Path(key).with_suffix(f".{coord_sidecar}.feather")
+            carfile = feather.read_table(path, memory_map=True)
+            for col in carfile.column_names:
+                if col in ["x", "y"]:
+                    tb = tb.append_column(col, carfile[col])
+            tbs.append(tb)
+        self._tb = pa.concat_tables(tbs)
         return self._tb

     @property
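Note: each tile is now assembled by treating its `datum_id` file as the root table and appending equal-length columns from the coordinate sidecar, so row alignment is purely positional. A toy sketch of that assembly step:

    import pyarrow as pa

    # one root table and one sidecar table per quadtree tile, same row order
    root = pa.table({"datum_id": ["a", "b", "c"]})
    sidecar = pa.table({"x": [0.1, 0.2, 0.3], "y": [1.0, 1.1, 1.2]})
    for col in ["x", "y"]:
        root = root.append_column(col, sidecar[col])
    tbs = [root]  # in the real code: one assembled table per tile key
    print(pa.concat_tables(tbs).column_names)  # ['datum_id', 'x', 'y']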
@@ -426,38 +524,40 @@ class AtlasMapEmbeddings:
         downloaded_files_in_tile_order = self._download_latent()
         assert len(downloaded_files_in_tile_order) > 0, "No embeddings found for this map."
         all_embeddings = []
-
-        for path in downloaded_files_in_tile_order:
-            # Should there be more than 10, we need to sort by int values, not string values
+        logger.info("Loading latent embeddings")
+        for path in tqdm(downloaded_files_in_tile_order):
             tb = feather.read_table(path, memory_map=True)
             dims = tb["_embeddings"].type.list_size
             all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
         return np.vstack(all_embeddings)

+    def _download_projected(self) -> List[Path]:
+        """
+        Downloads the feather tree for projection coordinates.
+        """
+        logger.info("Downloading projected embeddings")
+        # Note that y coord should be in same sidecar
+        coord_sidecar = self.projection._get_sidecar_from_field("x")
+        self.projection._download_sidecar("datum_id", overwrite=False)
+        return self.projection._download_sidecar(coord_sidecar, overwrite=False)
+
     def _download_latent(self) -> List[Path]:
         """
         Downloads the feather tree for embeddings.
         Returns the path to downloaded embeddings.
         """
         # TODO: Is size of the embedding files (several hundreds of MBs) going to be a problem here?
-
-
-
-
-
-
-
-
-
-
-
-        for quad in tqdm(all_quads):
-            path = quad.with_suffix(".embeddings.feather")
-            # WARNING: Potentially large data request here
-            quadtree_loc = Path(*path.parts[-3:])
-            download_feather(root_url / quadtree_loc, path, headers=self.dataset.header, overwrite=False)
-            downloaded_files_in_tile_order.append(path)
-        return downloaded_files_in_tile_order
+        logger.info("Downloading latent embeddings")
+        embedding_sidecar = None
+        for field, sidecar in self.projection._registered_columns:
+            # NOTE: be _embeddings or _embedding
+            if field == "_embeddings":
+                embedding_sidecar = sidecar
+                break
+
+        if embedding_sidecar is None:
+            raise ValueError("No embeddings found for this map.")
+        return self.projection._download_sidecar(embedding_sidecar, overwrite=False)

     def vector_search(
         self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5
@@ -571,12 +671,15 @@ class AtlasMapTags:
         self.projection = projection
         self.dataset = projection.dataset
         self.id_field = self.projection.dataset.id_field
-        # Pre-fetch
-
+        # Pre-fetch datum ids first upon initialization
+        try:
+            self.projection._download_sidecar("datum_id")
+        except Exception:
+            raise ValueError("Failed to fetch datum ids which is required to load tags.")
         self.auto_cleanup = auto_cleanup

     @property
-    def df(self, overwrite:
+    def df(self, overwrite: bool = False) -> pd.DataFrame:
         """
         Pandas DataFrame mapping each data point to its tags.
         """
@@ -587,16 +690,13 @@ class AtlasMapTags:
         for tag in tags:
             self._download_tag(tag["tag_name"], overwrite=overwrite)
         tbs = []
-
-        for
-
-
-            path = self.projection.tile_destination / Path(datum_id_filename)
-            tb = feather.read_table(path, memory_map=True)
+        logger.info("Loading tags")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+            tb = feather.read_table(datum_id_path, memory_map=True)
             for tag in tags:
                 tag_definition_id = tag["tag_definition_id"]
-
-                path = self.projection.tile_destination / Path(tag_filename)
+                path = self.projection.tile_destination / Path(key).with_suffix(f"._tag.{tag_definition_id}.feather")
                 tag_tb = feather.read_table(path, memory_map=True)
                 bitmask = None
                 if "all_set" in tag_tb.column_names:
@@ -635,7 +735,7 @@ class AtlasMapTags:
                 keep_tags.append(tag)
         return keep_tags

-    def get_datums_in_tag(self, tag_name: str, overwrite:
+    def get_datums_in_tag(self, tag_name: str, overwrite: bool = False):
         """
         Returns the datum ids in a given tag.

@@ -646,9 +746,9 @@ class AtlasMapTags:
         Returns:
             List of datum ids.
         """
-
+        tag_paths = self._download_tag(tag_name, overwrite=overwrite)
         datum_ids = []
-        for path in
+        for path in tag_paths:
             tb = feather.read_table(path)
             last_coord = path.name.split(".")[0]
             tile_path = path.with_name(last_coord + ".datum_id.feather")
@@ -675,26 +775,14 @@ class AtlasMapTags:
                 return tag
         raise ValueError(f"Tag {name} not found in projection {self.projection.id}.")

-    def _download_tag(self, tag_name: str, overwrite:
+    def _download_tag(self, tag_name: str, overwrite: bool = False):
         """
         Downloads the feather tree for large sidecar columns.
         """
         logger.info("Downloading tags")
-        self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
-        root_url = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
-
         tag = self._get_tag_by_name(tag_name)
         tag_definition_id = tag["tag_definition_id"]
-
-        all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        ordered_tag_paths = []
-        for quad in tqdm(all_quads):
-            quad_str = os.path.join(*[str(q) for q in quad])
-            filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
-            path = self.projection.tile_destination / Path(filename)
-            download_feather(root_url + filename, path, headers=self.dataset.header, overwrite=True)
-            ordered_tag_paths.append(path)
-        return ordered_tag_paths
+        return self.projection._download_sidecar(f"_tag.{tag_definition_id}", overwrite=overwrite)

     def _remove_outdated_tag_files(self, tag_definition_ids: List[str]):
         """
@@ -705,14 +793,12 @@ class AtlasMapTags:
             tag_definition_ids: A list of tag definition ids to keep.
         """
         # NOTE: This currently only gets triggered on `df` property
-
-
-            quad_str = os.path.join(*[str(q) for q in quad])
-            tile = self.projection.tile_destination / Path(quad_str)
+        for key in self.projection._manifest["key"].to_pylist():
+            tile = self.projection.tile_destination / Path(key)
             tile_dir = tile.parent
             if tile_dir.exists():
-
-                for file in
+                tag_files = tile_dir.glob("*_tag*")
+                for file in tag_files:
                     tag_definition_id = file.name.split(".")[-2]
                     if tag_definition_id in tag_definition_ids:
                         try:
@@ -764,61 +850,69 @@ class AtlasMapData:
         self.projection = projection
         self.dataset = projection.dataset
         self.id_field = self.projection.dataset.id_field
-
-        #
-        self.
-
-
+        if fields is None:
+            # TODO: fall back on something more reliable here
+            self.fields = self.dataset.dataset_fields
+        else:
+            for field in fields:
+                assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
+            self.fields = fields
+        self._tb = None

-
-
+    def _load_data(self, data_columns: List[Tuple[str, str]]):
+        """
+        Loads data from a list of data columns (field and sidecar name tuples).

-
+        Args:
+            data_columns: A list of tuples containing field name and sidecar name.
+        """
         tbs = []
-
-
-
-
-
-
-
+
+        sidecars_to_load = set([sidecar for _, sidecar in data_columns if sidecar != "datum_id"])
+        logger.info("Loading data")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+            )
+            for sidecar in sidecars_to_load:
+                path = self.projection.tile_destination
+                if sidecar == "":
+                    path = path / Path(key).with_suffix(".feather")
+                else:
+                    path = path / Path(key).with_suffix(f".{sidecar}.feather")
+                carfile = feather.read_table(path, memory_map=True)
                 for col in carfile.column_names:
-
+                    if col in self.fields:
+                        tb = tb.append_column(col, carfile[col])
             tbs.append(tb)
-        self._tb = pa.concat_tables(tbs)

-
+        self._tb = pa.concat_tables(tbs)

-    def _download_data(self, fields=None):
+    def _download_data(self, fields: Optional[List[str]] = None) -> List[Tuple[str, str]]:
         """
-        Downloads the feather tree for
+        Downloads the feather tree for user uploaded data.
+
+        fields:
+            A list of fields to download. If None, downloads all fields.
+
+        Returns:
+            List of downloaded columns
         """
-        logger.info("Downloading
+        logger.info("Downloading data")
         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
-        root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
-
-        all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        sidecars = None
-        if fields is None:
-            fields = self.dataset.dataset_fields
-        else:
-            for field in fields:
-                assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."

-
-
-
-
+        # Download specified or all sidecar fields + always download datum_id
+        data_columns_to_load = [
+            (str(field), str(sidecar))
+            for field, sidecar in self.projection._registered_columns
+            if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
         ]

-
-
-
-
-            path = self.projection.tile_destination / Path(filename)
-            # WARNING: Potentially large data request here
-            download_feather(root + filename, path, headers=self.dataset.header, overwrite=False)
-        return sidecars
+        # TODO: less confusing progress bar
+        for sidecar in set([sidecar for _, sidecar in data_columns_to_load]):
+            self.projection._download_sidecar(sidecar)
+        return data_columns_to_load

     @property
     def df(self) -> pd.DataFrame:
@@ -826,7 +920,8 @@ class AtlasMapData:
         A pandas DataFrame associating each datapoint on your map to their metadata.
         Converting to pandas DataFrame may materialize a large amount of data into memory.
         """
-
+        logger.warning("Converting to pandas dataframe. This may materialize a large amount of data into memory.")
+        return self.tb.to_pandas()

     @property
     def tb(self) -> pa.Table:
@@ -835,4 +930,9 @@ class AtlasMapData:
         This table is memmapped from the underlying files and is the most efficient way to
         access metadata information.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+
+        columns = self._download_data(fields=self.fields)
+        self._load_data(columns)
         return self._tb


{nomic-3.0.32 → nomic-3.0.34}/nomic/dataset.py

@@ -4,28 +4,21 @@ import concurrent.futures
 import io
 import json
 import os
-import pickle
 import time
-import uuid
-from collections import defaultdict
 from contextlib import contextmanager
-from datetime import
+from datetime import datetime
 from pathlib import Path
-from typing import
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
-import pandas as pd
 import pyarrow as pa
 import requests
 from loguru import logger
 from pandas import DataFrame
 from pyarrow import compute as pc
 from pyarrow import feather, ipc
-from pydantic import BaseModel, Field
 from tqdm import tqdm

-import nomic
-
 from .cli import refresh_bearer_token, validate_api_http_response
 from .data_inference import (
     NomicDuplicatesOptions,
@@ -433,6 +426,8 @@ class AtlasProjection:
         self._tile_data = None
         self._data = None
         self._schema = None
+        self._manifest_tb: Optional[pa.Table] = None
+        self._columns: List[Tuple[str, str]] = []

     @property
     def map_link(self):
@@ -590,130 +585,77 @@ class AtlasProjection:
         self._schema = ipc.read_schema(io.BytesIO(content))
         return self._schema

-
+    @property
+    def _registered_columns(self) -> List[Tuple[str, str]]:
         "Returns [(field_name, sidecar_name), ...]"
-
+        if self._columns:
+            return self._columns
+        self._columns = []
         for field in self.schema:
             sidecar_name = json.loads(field.metadata.get(b"sidecar_name", b'""'))
-            if sidecar_name:
-
-        return
+            if sidecar_name is not None:
+                self._columns.append((field.name, sidecar_name))
+        return self._columns

-
+    @property
+    def _manifest(self) -> pa.Table:
         """
-
-
-
-
-
-
-        An Arrow table containing information for all data points in the index.
-        """
-        if self._tile_data is not None:
-            return self._tile_data
-        logger.info(f"Downloading files for projection {self.projection_id}")
-        self._download_large_feather(overwrite=overwrite)
-        tbs = []
-        root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True)
-        try:
-            sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()])
-        except KeyError:
-            sidecars = set([])
-        sidecars |= set(sidecar_name for (_, sidecar_name) in self._registered_sidecars())
-        for path in self._tiles_in_order():
-            tb = pa.feather.read_table(path, memory_map=True)  # type: ignore
-            for sidecar_file in sidecars:
-                carfile = pa.feather.read_table(  # type: ignore
-                    path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True
-                )
-                for col in carfile.column_names:
-                    tb = tb.append_column(col, carfile[col])
-            tbs.append(tb)
-        self._tile_data = pa.concat_tables(tbs)
-
-        return self._tile_data
-
-        @overload
-        def _tiles_in_order(self, *, coords_only: Literal[False] = ...) -> Iterator[Path]: ...
+        Returns the tile manifest for the projection.
+        Tile manifest is in quadtree order. All quadtree operations should
+        depend on tile manifest to ensure consistency.
+        """
+        if self._manifest_tb is not None:
+            return self._manifest_tb

-
-
+        manifest_path = self.tile_destination / "manifest.feather"
+        manifest_url = (
+            self.dataset.atlas_api_path
+            + f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/manifest.feather"
+        )

-
-
+        download_feather(manifest_url, manifest_path, headers=self.dataset.header, overwrite=False)
+        self._manifest_tb = feather.read_table(manifest_path, memory_map=False)
+        return self._manifest_tb

-    def
+    def _get_sidecar_from_field(self, field: str) -> str:
         """
-        Returns
-        A list of all tiles in the projection in a fixed order so that all
-        datasets are guaranteed to be aligned.
-        """
-
-        def children(z, x, y):
-            # This is the definition of a quadtree.
-            return [
-                (z + 1, x * 2, y * 2),
-                (z + 1, x * 2 + 1, y * 2),
-                (z + 1, x * 2, y * 2 + 1),
-                (z + 1, x * 2 + 1, y * 2 + 1),
-            ]
-
-        # start with the root
-        paths = [(0, 0, 0)]
-        # Pop off the front, extend the back (breadth first traversal)
-        while len(paths) > 0:
-            z, x, y = paths.pop(0)
-            path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix(".feather")
-            if path.exists():
-                if coords_only:
-                    yield (z, x, y)
-                else:
-                    yield path
-                paths.extend(children(z, x, y))  # pyright: ignore
+        Returns the sidecar name for a given field.

-
-
-
+        Args:
+            field: the name of the field
+        """
+        for f, sidecar in self._registered_columns:
+            if field == f:
+                return sidecar
+        raise ValueError(f"Field {field} not found in registered columns.")

-    def
+    def _download_sidecar(self, sidecar_name, overwrite: bool = False) -> List[Path]:
         """
-        Downloads the
+        Downloads sidecar files from the quadtree
         Args:
+            sidecar_name: the name of the sidecar file
             overwrite: if True then overwrite existing feather files.

         Returns:
-
-        """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        elif sidecars is None:
-            sidecars = set()
-        if not "." in rawquad:
-            for sidecar in sidecars | registered_sidecars:
-                # The sidecar loses the feather suffix because it's supposed to be raw.
-                quads.append(quad.replace(".feather", f".{sidecar}"))
-        if not schema.metadata or b"children" not in schema.metadata:
-            # Sidecars don't have children.
-            continue
-        kids = schema.metadata.get(b"children")
-        children = json.loads(kids)
-        quads.extend(children)
-        return all_quads
+            List of downloaded feather files.
+        """
+        downloaded_files = []
+        sidecar_suffix = "feather"
+        if sidecar_name != "":
+            sidecar_suffix = f"{sidecar_name}.feather"
+        for key in tqdm(self._manifest["key"].to_pylist()):
+            sidecar_path = self.tile_destination / f"{key}.{sidecar_suffix}"
+            sidecar_url = (
+                self.dataset.atlas_api_path
+                + f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/{key}.{sidecar_suffix}"
+            )
+            download_feather(sidecar_url, sidecar_path, headers=self.dataset.header, overwrite=overwrite)
+            downloaded_files.append(sidecar_path)
+        return downloaded_files
+
+    @property
+    def tile_destination(self):
+        return Path("~/.nomic/cache", self.id).expanduser()

     @property
     def datum_id_field(self):
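Note: under the new scheme, `manifest.feather` lists the quadtree tile keys, and every sidecar of a tile named `z/x/y` lives beside the root tile as `z/x/y.<sidecar>.feather`. A sketch of that key-to-filename mapping with made-up values (the projection id and the `_tag.abc123` sidecar name are placeholders):

    from pathlib import Path

    tile_destination = Path("~/.nomic/cache/projection-id").expanduser()  # placeholder id
    key = "0/0/0"  # a hypothetical entry from manifest.feather

    for sidecar_name in ["", "datum_id", "_tag.abc123"]:  # "" is the root tile itself
        suffix = "feather" if sidecar_name == "" else f"{sidecar_name}.feather"
        print(tile_destination / f"{key}.{suffix}")
    # .../0/0/0.feather
    # .../0/0/0.datum_id.feather
    # .../0/0/0._tag.abc123.feather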

{nomic-3.0.32 → nomic-3.0.34}/nomic/utils.py

@@ -243,30 +243,48 @@ def get_object_size_in_bytes(obj):
 # Helpful function for downloading feather files
 # Best for small feather files
 def download_feather(
-    url: Union[str, Path], path: Path, headers: Optional[dict] = None,
+    url: Union[str, Path], path: Path, headers: Optional[dict] = None, num_attempts=1, overwrite=False
 ) -> pa.Schema:
     """
     Download a feather file from a URL to a local path.
     Returns the schema of feather file if successful.
+
+    Parameters:
+        url (str): URL to download feather file from.
+        path (Path): Local path to save feather file to.
+        headers (dict): Optional headers to include in request.
+        num_attempts (int): Number of download attempts before raising an error.
+        overwrite (bool): Whether to overwrite existing file.
+    Returns:
+        Feather schema.
     """
-    assert
+    assert num_attempts > 0, "Num attempts must be greater than 0"
     download_attempt = 0
    download_success = False
     schema = None
-    while download_attempt <
+    while download_attempt < num_attempts and not download_success:
         download_attempt += 1
         if not path.exists() or overwrite:
-
-
-
-
-
-
-
-
-
-
-
+            # Attempt download
+            try:
+                data = requests.get(str(url), headers=headers)
+                readable = BytesIO(data.content)
+                readable.seek(0)
+                tb = pa.feather.read_table(readable, memory_map=False)  # type: ignore
+                schema = tb.schema
+                path.parent.mkdir(parents=True, exist_ok=True)
+                pa.feather.write_feather(tb, path)  # type: ignore
+                download_success = True
+            except pa.ArrowInvalid:
+                # failed try again
+                path.unlink(missing_ok=True)
+        else:
+            # Load existing file
+            try:
+                schema = ipc.open_file(path).schema
+                download_success = True
+            except pa.ArrowInvalid:
+                path.unlink(missing_ok=True)
     if not download_success or schema is None:
-        raise ValueError(f"Failed to download feather file from {url} after {
+        raise ValueError(f"Failed to download feather file from {url} after {num_attempts} attempts.")
     return schema
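Note: a hypothetical call against the new signature; the host, ids, and token below are placeholders (the manifest URL shape is taken from the dataset.py hunk above). The call retries up to `num_attempts` times and reuses a cached file unless `overwrite=True`:

    from pathlib import Path

    from nomic.utils import download_feather

    schema = download_feather(
        "https://api-atlas.nomic.ai/v1/project/<project-id>/index/projection/<projection-id>/quadtree/manifest.feather",
        Path("/tmp/manifest.feather"),
        headers={"Authorization": "Bearer <token>"},
        num_attempts=3,
        overwrite=False,
    )
    print(schema)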