nomic 3.0.31__tar.gz → 3.0.33__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nomic might be problematic. Click here for more details.
- {nomic-3.0.31 → nomic-3.0.33}/PKG-INFO +1 -1
- {nomic-3.0.31 → nomic-3.0.33}/nomic/aws/sagemaker.py +0 -1
- {nomic-3.0.31 → nomic-3.0.33}/nomic/data_inference.py +2 -2
- {nomic-3.0.31 → nomic-3.0.33}/nomic/data_operations.py +250 -197
- {nomic-3.0.31 → nomic-3.0.33}/nomic/dataset.py +62 -137
- {nomic-3.0.31 → nomic-3.0.33}/nomic/utils.py +48 -8
- {nomic-3.0.31 → nomic-3.0.33}/nomic.egg-info/PKG-INFO +1 -1
- {nomic-3.0.31 → nomic-3.0.33}/nomic.egg-info/requires.txt +1 -1
- {nomic-3.0.31 → nomic-3.0.33}/setup.py +2 -2
- {nomic-3.0.31 → nomic-3.0.33}/README.md +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/__init__.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/atlas.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/aws/__init__.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/cli.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/embed.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/pl_callbacks/__init__.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/pl_callbacks/pl_callback.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic/settings.py +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic.egg-info/SOURCES.txt +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic.egg-info/dependency_links.txt +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic.egg-info/entry_points.txt +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/nomic.egg-info/top_level.txt +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/pyproject.toml +0 -0
- {nomic-3.0.31 → nomic-3.0.33}/setup.cfg +0 -0
|
@@ -84,11 +84,11 @@ class NomicTopicOptions(BaseModel):
|
|
|
84
84
|
|
|
85
85
|
Args:
|
|
86
86
|
build_topic_model: If True, builds a topic model over your dataset's embeddings.
|
|
87
|
-
|
|
87
|
+
topic_label_field: The dataset column (usually the column you embedded) that Atlas will use to assign a human-readable description to each topic.
|
|
88
88
|
"""
|
|
89
89
|
|
|
90
90
|
build_topic_model: bool = True
|
|
91
|
-
|
|
91
|
+
topic_label_field: Optional[str] = Field(default=None, alias="community_description_target_field")
|
|
92
92
|
cluster_method: str = "fast"
|
|
93
93
|
enforce_topic_hierarchy: bool = False
|
|
94
94
|
|
|
@@ -1,15 +1,10 @@
|
|
|
1
1
|
import base64
|
|
2
|
-
import concurrent
|
|
3
|
-
import concurrent.futures
|
|
4
|
-
import glob
|
|
5
2
|
import io
|
|
6
3
|
import json
|
|
7
|
-
import os
|
|
8
4
|
from collections import defaultdict
|
|
9
5
|
from datetime import datetime
|
|
10
|
-
from io import BytesIO
|
|
11
6
|
from pathlib import Path
|
|
12
|
-
from typing import Dict, Iterable, List, Optional, Tuple
|
|
7
|
+
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
|
13
8
|
|
|
14
9
|
import numpy as np
|
|
15
10
|
import pandas as pd
|
|
@@ -17,12 +12,9 @@ import pyarrow as pa
|
|
|
17
12
|
import requests
|
|
18
13
|
from loguru import logger
|
|
19
14
|
from pyarrow import compute as pc
|
|
20
|
-
from pyarrow import feather
|
|
15
|
+
from pyarrow import feather
|
|
21
16
|
from tqdm import tqdm
|
|
22
17
|
|
|
23
|
-
from .settings import EMBEDDING_PAGINATION_LIMIT
|
|
24
|
-
from .utils import download_feather
|
|
25
|
-
|
|
26
18
|
|
|
27
19
|
class AtlasMapDuplicates:
|
|
28
20
|
"""
|
|
@@ -34,22 +26,57 @@ class AtlasMapDuplicates:
|
|
|
34
26
|
def __init__(self, projection: "AtlasProjection"): # type: ignore
|
|
35
27
|
self.projection = projection
|
|
36
28
|
self.id_field = self.projection.dataset.id_field
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
self.
|
|
45
|
-
|
|
46
|
-
|
|
29
|
+
|
|
30
|
+
duplicate_columns = [
|
|
31
|
+
(field, sidecar)
|
|
32
|
+
for field, sidecar in self.projection._registered_columns
|
|
33
|
+
if field.startswith("_duplicate_class")
|
|
34
|
+
]
|
|
35
|
+
cluster_columns = [
|
|
36
|
+
(field, sidecar) for field, sidecar in self.projection._registered_columns if field.startswith("_cluster")
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
assert len(duplicate_columns) > 0, "Duplicate detection has not yet been run on this map."
|
|
40
|
+
|
|
41
|
+
self._duplicate_column = duplicate_columns[0]
|
|
42
|
+
self._cluster_column = cluster_columns[0]
|
|
43
|
+
self._tb = None
|
|
44
|
+
|
|
45
|
+
def _load_duplicates(self):
|
|
46
|
+
"""
|
|
47
|
+
Loads duplicates from the feather tree.
|
|
48
|
+
"""
|
|
49
|
+
tbs = []
|
|
50
|
+
duplicate_sidecar = self._duplicate_column[1]
|
|
51
|
+
self.duplicate_field = self._duplicate_column[0].lstrip("_")
|
|
52
|
+
self.cluster_field = self._cluster_column[0].lstrip("_")
|
|
53
|
+
logger.info("Loading duplicates")
|
|
54
|
+
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
55
|
+
# Use datum id as root table
|
|
56
|
+
tb = feather.read_table(
|
|
57
|
+
self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
|
|
47
58
|
)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
59
|
+
path = self.projection.tile_destination
|
|
60
|
+
|
|
61
|
+
if duplicate_sidecar == "":
|
|
62
|
+
path = path / Path(key).with_suffix(".feather")
|
|
63
|
+
else:
|
|
64
|
+
path = path / Path(key).with_suffix(f".{duplicate_sidecar}.feather")
|
|
65
|
+
|
|
66
|
+
duplicate_tb = feather.read_table(path, memory_map=True)
|
|
67
|
+
for field in (self._duplicate_column[0], self._cluster_column[0]):
|
|
68
|
+
tb = tb.append_column(field, duplicate_tb[field])
|
|
69
|
+
tbs.append(tb)
|
|
70
|
+
self._tb = pa.concat_tables(tbs).rename_columns([self.id_field, self.duplicate_field, self.cluster_field])
|
|
71
|
+
|
|
72
|
+
def _download_duplicates(self):
|
|
73
|
+
"""
|
|
74
|
+
Downloads the feather tree for duplicates.
|
|
75
|
+
"""
|
|
76
|
+
logger.info("Downloading duplicates")
|
|
77
|
+
self.projection._download_sidecar("datum_id", overwrite=False)
|
|
78
|
+
assert self._cluster_column[1] == self._duplicate_column[1], "Cluster and duplicate should be in same sidecar"
|
|
79
|
+
self.projection._download_sidecar(self._duplicate_column[1], overwrite=False)
|
|
53
80
|
|
|
54
81
|
@property
|
|
55
82
|
def df(self) -> pd.DataFrame:
|
|
@@ -65,6 +92,10 @@ class AtlasMapDuplicates:
|
|
|
65
92
|
This table is memmapped from the underlying files and is the most efficient way to
|
|
66
93
|
access duplicate information.
|
|
67
94
|
"""
|
|
95
|
+
if isinstance(self._tb, pa.Table):
|
|
96
|
+
return self._tb
|
|
97
|
+
self._download_duplicates()
|
|
98
|
+
self._load_duplicates()
|
|
68
99
|
return self._tb
|
|
69
100
|
|
|
70
101
|
def deletion_candidates(self) -> List[str]:
|
|
@@ -97,34 +128,69 @@ class AtlasMapTopics:
|
|
|
97
128
|
self.id_field = self.projection.dataset.id_field
|
|
98
129
|
self._metadata = None
|
|
99
130
|
self._hierarchy = None
|
|
131
|
+
self._topic_columns = [
|
|
132
|
+
column for column in self.projection._registered_columns if column[0].startswith("_topic_depth_")
|
|
133
|
+
]
|
|
134
|
+
assert len(self._topic_columns) > 0, "Topic modeling has not yet been run on this map."
|
|
135
|
+
self.depth = len(self._topic_columns)
|
|
136
|
+
self._tb = None
|
|
137
|
+
|
|
138
|
+
def _load_topics(self):
|
|
139
|
+
"""
|
|
140
|
+
Loads topics from the feather tree.
|
|
141
|
+
"""
|
|
142
|
+
integer_topics = False
|
|
143
|
+
# pd.Series to match pd typing
|
|
144
|
+
label_df: Optional[Union[pd.DataFrame, pd.Series]] = None
|
|
145
|
+
if "int" in self._topic_columns[0][0]:
|
|
146
|
+
integer_topics = True
|
|
147
|
+
label_df = self.metadata[["topic_id", "depth", "topic_short_description"]]
|
|
148
|
+
tbs = []
|
|
149
|
+
# Should just be one sidecar
|
|
150
|
+
topic_sidecar = set([sidecar for _, sidecar in self._topic_columns]).pop()
|
|
151
|
+
logger.info("Loading topics")
|
|
152
|
+
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
153
|
+
# Use datum id as root table
|
|
154
|
+
tb = feather.read_table(
|
|
155
|
+
self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
|
|
156
|
+
)
|
|
157
|
+
path = self.projection.tile_destination
|
|
158
|
+
if topic_sidecar == "":
|
|
159
|
+
path = path / Path(key).with_suffix(".feather")
|
|
160
|
+
else:
|
|
161
|
+
path = path / Path(key).with_suffix(f".{topic_sidecar}.feather")
|
|
100
162
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
# If using topic ids, fetch topic labels
|
|
107
|
-
if "int" in topic_fields[0]:
|
|
108
|
-
new_topic_fields = []
|
|
109
|
-
label_df = self.metadata[["topic_id", "depth", "topic_short_description"]]
|
|
110
|
-
for d in range(1, self.depth + 1):
|
|
163
|
+
topic_tb = feather.read_table(path, memory_map=True)
|
|
164
|
+
# Do this in depth order
|
|
165
|
+
for d in range(1, self.depth + 1):
|
|
166
|
+
column = f"_topic_depth_{d}"
|
|
167
|
+
if integer_topics:
|
|
111
168
|
column = f"_topic_depth_{d}_int"
|
|
112
|
-
topic_ids_to_label =
|
|
169
|
+
topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
|
|
170
|
+
assert label_df is not None
|
|
113
171
|
topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
|
|
114
172
|
topic_ids_to_label, on="topic_id", how="right"
|
|
115
173
|
)
|
|
116
174
|
new_column = f"_topic_depth_{d}"
|
|
117
|
-
|
|
175
|
+
tb = tb.append_column(
|
|
118
176
|
new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
|
|
119
177
|
)
|
|
120
|
-
|
|
121
|
-
|
|
178
|
+
else:
|
|
179
|
+
tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
|
|
180
|
+
tbs.append(tb)
|
|
122
181
|
|
|
123
|
-
|
|
124
|
-
|
|
182
|
+
renamed_columns = [self.id_field] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
|
|
183
|
+
self._tb = pa.concat_tables(tbs).rename_columns(renamed_columns)
|
|
125
184
|
|
|
126
|
-
|
|
127
|
-
|
|
185
|
+
def _download_topics(self):
|
|
186
|
+
"""
|
|
187
|
+
Downloads the feather tree for topics.
|
|
188
|
+
"""
|
|
189
|
+
logger.info("Downloading topics")
|
|
190
|
+
self.projection._download_sidecar("datum_id", overwrite=False)
|
|
191
|
+
topic_sidecars = set([sidecar for _, sidecar in self._topic_columns])
|
|
192
|
+
assert len(topic_sidecars) == 1, "Multiple topic sidecars found."
|
|
193
|
+
self.projection._download_sidecar(topic_sidecars.pop(), overwrite=False)
|
|
128
194
|
|
|
129
195
|
@property
|
|
130
196
|
def df(self) -> pd.DataFrame:
|
|
@@ -140,6 +206,10 @@ class AtlasMapTopics:
|
|
|
140
206
|
This table is memmapped from the underlying files and is the most efficient way to
|
|
141
207
|
access topic information.
|
|
142
208
|
"""
|
|
209
|
+
if isinstance(self._tb, pa.Table):
|
|
210
|
+
return self._tb
|
|
211
|
+
self._download_topics()
|
|
212
|
+
self._load_topics()
|
|
143
213
|
return self._tb
|
|
144
214
|
|
|
145
215
|
@property
|
|
@@ -263,8 +333,8 @@ class AtlasMapTopics:
|
|
|
263
333
|
A list of `{topic, count}` dictionaries, sorted from largest count to smallest count.
|
|
264
334
|
"""
|
|
265
335
|
data = AtlasMapData(self.projection, fields=[time_field])
|
|
266
|
-
time_data = data.
|
|
267
|
-
merged_tb = self.
|
|
336
|
+
time_data = data.tb.select([self.id_field, time_field])
|
|
337
|
+
merged_tb = self.tb.join(time_data, self.id_field, join_type="inner").combine_chunks()
|
|
268
338
|
|
|
269
339
|
del time_data # free up memory
|
|
270
340
|
|
|
@@ -379,8 +449,8 @@ class AtlasMapEmbeddings:
|
|
|
379
449
|
def __init__(self, projection: "AtlasProjection"): # type: ignore
|
|
380
450
|
self.projection = projection
|
|
381
451
|
self.id_field = self.projection.dataset.id_field
|
|
382
|
-
self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, "x", "y"])
|
|
383
452
|
self.dataset = projection.dataset
|
|
453
|
+
self._tb: pa.Table = None
|
|
384
454
|
self._latent = None
|
|
385
455
|
|
|
386
456
|
@property
|
|
@@ -401,6 +471,31 @@ class AtlasMapEmbeddings:
|
|
|
401
471
|
|
|
402
472
|
Does not include high-dimensional embeddings.
|
|
403
473
|
"""
|
|
474
|
+
if isinstance(self._tb, pa.Table):
|
|
475
|
+
return self._tb
|
|
476
|
+
|
|
477
|
+
self._download_projected()
|
|
478
|
+
|
|
479
|
+
logger.info("Loading projected embeddings")
|
|
480
|
+
|
|
481
|
+
tbs = []
|
|
482
|
+
coord_sidecar = self.projection._get_sidecar_from_field("x")
|
|
483
|
+
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
484
|
+
# Use datum id as root table
|
|
485
|
+
tb = feather.read_table(
|
|
486
|
+
self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
|
|
487
|
+
)
|
|
488
|
+
path = self.projection.tile_destination
|
|
489
|
+
if coord_sidecar == "":
|
|
490
|
+
path = path / Path(key).with_suffix(".feather")
|
|
491
|
+
else:
|
|
492
|
+
path = path / Path(key).with_suffix(f".{coord_sidecar}.feather")
|
|
493
|
+
carfile = feather.read_table(path, memory_map=True)
|
|
494
|
+
for col in carfile.column_names:
|
|
495
|
+
if col in ["x", "y"]:
|
|
496
|
+
tb = tb.append_column(col, carfile[col])
|
|
497
|
+
tbs.append(tb)
|
|
498
|
+
self._tb = pa.concat_tables(tbs)
|
|
404
499
|
return self._tb
|
|
405
500
|
|
|
406
501
|
@property
|
|
@@ -426,53 +521,43 @@ class AtlasMapEmbeddings:
|
|
|
426
521
|
if self._latent is not None:
|
|
427
522
|
return self._latent
|
|
428
523
|
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
if not root_embedding.exists():
|
|
432
|
-
self._download_latent()
|
|
524
|
+
downloaded_files_in_tile_order = self._download_latent()
|
|
525
|
+
assert len(downloaded_files_in_tile_order) > 0, "No embeddings found for this map."
|
|
433
526
|
all_embeddings = []
|
|
434
|
-
|
|
435
|
-
for path in
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1]))
|
|
440
|
-
if len(sortable) == 0:
|
|
441
|
-
raise FileNotFoundError(
|
|
442
|
-
"Could not find any embeddings for tile {}".format(path)
|
|
443
|
-
+ " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'."
|
|
444
|
-
)
|
|
445
|
-
for file in sortable:
|
|
446
|
-
tb = feather.read_table(file, memory_map=True)
|
|
447
|
-
dims = tb["_embeddings"].type.list_size
|
|
448
|
-
all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims)) # type: ignore
|
|
527
|
+
logger.info("Loading latent embeddings")
|
|
528
|
+
for path in tqdm(downloaded_files_in_tile_order):
|
|
529
|
+
tb = feather.read_table(path, memory_map=True)
|
|
530
|
+
dims = tb["_embeddings"].type.list_size
|
|
531
|
+
all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims)) # type: ignore
|
|
449
532
|
return np.vstack(all_embeddings)
|
|
450
533
|
|
|
451
|
-
def
|
|
534
|
+
def _download_projected(self) -> List[Path]:
|
|
452
535
|
"""
|
|
453
|
-
Downloads the
|
|
536
|
+
Downloads the feather tree for projection coordinates.
|
|
454
537
|
"""
|
|
455
|
-
logger.
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
538
|
+
logger.info("Downloading projected embeddings")
|
|
539
|
+
# Note that y coord should be in same sidecar
|
|
540
|
+
coord_sidecar = self.projection._get_sidecar_from_field("x")
|
|
541
|
+
self.projection._download_sidecar("datum_id", overwrite=False)
|
|
542
|
+
return self.projection._download_sidecar(coord_sidecar, overwrite=False)
|
|
459
543
|
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
544
|
+
def _download_latent(self) -> List[Path]:
|
|
545
|
+
"""
|
|
546
|
+
Downloads the feather tree for embeddings.
|
|
547
|
+
Returns the path to downloaded embeddings.
|
|
548
|
+
"""
|
|
549
|
+
# TODO: Is size of the embedding files (several hundreds of MBs) going to be a problem here?
|
|
550
|
+
logger.info("Downloading latent embeddings")
|
|
551
|
+
embedding_sidecar = None
|
|
552
|
+
for field, sidecar in self.projection._registered_columns:
|
|
553
|
+
# NOTE: be _embeddings or _embedding
|
|
554
|
+
if field == "_embeddings":
|
|
555
|
+
embedding_sidecar = sidecar
|
|
556
|
+
break
|
|
469
557
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
feather.write_feather(tb, dest)
|
|
474
|
-
last = tilename
|
|
475
|
-
pbar.update(1)
|
|
558
|
+
if embedding_sidecar is None:
|
|
559
|
+
raise ValueError("No embeddings found for this map.")
|
|
560
|
+
return self.projection._download_sidecar(embedding_sidecar, overwrite=False)
|
|
476
561
|
|
|
477
562
|
def vector_search(
|
|
478
563
|
self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5
|
|
@@ -586,12 +671,15 @@ class AtlasMapTags:
|
|
|
586
671
|
self.projection = projection
|
|
587
672
|
self.dataset = projection.dataset
|
|
588
673
|
self.id_field = self.projection.dataset.id_field
|
|
589
|
-
# Pre-fetch
|
|
590
|
-
|
|
674
|
+
# Pre-fetch datum ids first upon initialization
|
|
675
|
+
try:
|
|
676
|
+
self.projection._download_sidecar("datum_id")
|
|
677
|
+
except Exception:
|
|
678
|
+
raise ValueError("Failed to fetch datum ids which is required to load tags.")
|
|
591
679
|
self.auto_cleanup = auto_cleanup
|
|
592
680
|
|
|
593
681
|
@property
|
|
594
|
-
def df(self, overwrite:
|
|
682
|
+
def df(self, overwrite: bool = False) -> pd.DataFrame:
|
|
595
683
|
"""
|
|
596
684
|
Pandas DataFrame mapping each data point to its tags.
|
|
597
685
|
"""
|
|
@@ -602,16 +690,13 @@ class AtlasMapTags:
|
|
|
602
690
|
for tag in tags:
|
|
603
691
|
self._download_tag(tag["tag_name"], overwrite=overwrite)
|
|
604
692
|
tbs = []
|
|
605
|
-
|
|
606
|
-
for
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
path = self.projection.tile_destination / Path(datum_id_filename)
|
|
610
|
-
tb = feather.read_table(path, memory_map=True)
|
|
693
|
+
logger.info("Loading tags")
|
|
694
|
+
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
695
|
+
datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
|
|
696
|
+
tb = feather.read_table(datum_id_path, memory_map=True)
|
|
611
697
|
for tag in tags:
|
|
612
698
|
tag_definition_id = tag["tag_definition_id"]
|
|
613
|
-
|
|
614
|
-
path = self.projection.tile_destination / Path(tag_filename)
|
|
699
|
+
path = self.projection.tile_destination / Path(key).with_suffix(f"._tag.{tag_definition_id}.feather")
|
|
615
700
|
tag_tb = feather.read_table(path, memory_map=True)
|
|
616
701
|
bitmask = None
|
|
617
702
|
if "all_set" in tag_tb.column_names:
|
|
@@ -650,7 +735,7 @@ class AtlasMapTags:
|
|
|
650
735
|
keep_tags.append(tag)
|
|
651
736
|
return keep_tags
|
|
652
737
|
|
|
653
|
-
def get_datums_in_tag(self, tag_name: str, overwrite:
|
|
738
|
+
def get_datums_in_tag(self, tag_name: str, overwrite: bool = False):
|
|
654
739
|
"""
|
|
655
740
|
Returns the datum ids in a given tag.
|
|
656
741
|
|
|
@@ -661,9 +746,9 @@ class AtlasMapTags:
|
|
|
661
746
|
Returns:
|
|
662
747
|
List of datum ids.
|
|
663
748
|
"""
|
|
664
|
-
|
|
749
|
+
tag_paths = self._download_tag(tag_name, overwrite=overwrite)
|
|
665
750
|
datum_ids = []
|
|
666
|
-
for path in
|
|
751
|
+
for path in tag_paths:
|
|
667
752
|
tb = feather.read_table(path)
|
|
668
753
|
last_coord = path.name.split(".")[0]
|
|
669
754
|
tile_path = path.with_name(last_coord + ".datum_id.feather")
|
|
@@ -690,38 +775,14 @@ class AtlasMapTags:
|
|
|
690
775
|
return tag
|
|
691
776
|
raise ValueError(f"Tag {name} not found in projection {self.projection.id}.")
|
|
692
777
|
|
|
693
|
-
def _download_tag(self, tag_name: str, overwrite:
|
|
778
|
+
def _download_tag(self, tag_name: str, overwrite: bool = False):
|
|
694
779
|
"""
|
|
695
780
|
Downloads the feather tree for large sidecar columns.
|
|
696
781
|
"""
|
|
697
|
-
|
|
698
|
-
root_url = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
|
|
699
|
-
|
|
782
|
+
logger.info("Downloading tags")
|
|
700
783
|
tag = self._get_tag_by_name(tag_name)
|
|
701
784
|
tag_definition_id = tag["tag_definition_id"]
|
|
702
|
-
|
|
703
|
-
all_quads = list(self.projection._tiles_in_order(coords_only=True))
|
|
704
|
-
ordered_tag_paths = []
|
|
705
|
-
for quad in tqdm(all_quads):
|
|
706
|
-
quad_str = os.path.join(*[str(q) for q in quad])
|
|
707
|
-
filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
|
|
708
|
-
path = self.projection.tile_destination / Path(filename)
|
|
709
|
-
download_attempt = 0
|
|
710
|
-
download_success = False
|
|
711
|
-
while download_attempt < 3 and not download_success:
|
|
712
|
-
download_attempt += 1
|
|
713
|
-
if not path.exists() or overwrite:
|
|
714
|
-
download_feather(root_url + filename, path, headers=self.dataset.header)
|
|
715
|
-
try:
|
|
716
|
-
ipc.open_file(path).schema
|
|
717
|
-
download_success = True
|
|
718
|
-
except pa.ArrowInvalid:
|
|
719
|
-
path.unlink(missing_ok=True)
|
|
720
|
-
|
|
721
|
-
if not download_success:
|
|
722
|
-
raise Exception(f"Failed to download tag {tag_name}.")
|
|
723
|
-
ordered_tag_paths.append(path)
|
|
724
|
-
return ordered_tag_paths
|
|
785
|
+
return self.projection._download_sidecar(f"_tag.{tag_definition_id}", overwrite=overwrite)
|
|
725
786
|
|
|
726
787
|
def _remove_outdated_tag_files(self, tag_definition_ids: List[str]):
|
|
727
788
|
"""
|
|
@@ -732,14 +793,12 @@ class AtlasMapTags:
|
|
|
732
793
|
tag_definition_ids: A list of tag definition ids to keep.
|
|
733
794
|
"""
|
|
734
795
|
# NOTE: This currently only gets triggered on `df` property
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
quad_str = os.path.join(*[str(q) for q in quad])
|
|
738
|
-
tile = self.projection.tile_destination / Path(quad_str)
|
|
796
|
+
for key in self.projection._manifest["key"].to_pylist():
|
|
797
|
+
tile = self.projection.tile_destination / Path(key)
|
|
739
798
|
tile_dir = tile.parent
|
|
740
799
|
if tile_dir.exists():
|
|
741
|
-
|
|
742
|
-
for file in
|
|
800
|
+
tag_files = tile_dir.glob("*_tag*")
|
|
801
|
+
for file in tag_files:
|
|
743
802
|
tag_definition_id = file.name.split(".")[-2]
|
|
744
803
|
if tag_definition_id in tag_definition_ids:
|
|
745
804
|
try:
|
|
@@ -791,81 +850,69 @@ class AtlasMapData:
|
|
|
791
850
|
self.projection = projection
|
|
792
851
|
self.dataset = projection.dataset
|
|
793
852
|
self.id_field = self.projection.dataset.id_field
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
853
|
+
if fields is None:
|
|
854
|
+
# TODO: fall back on something more reliable here
|
|
855
|
+
self.fields = self.dataset.dataset_fields
|
|
856
|
+
else:
|
|
857
|
+
for field in fields:
|
|
858
|
+
assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
|
|
859
|
+
self.fields = fields
|
|
860
|
+
self._tb = None
|
|
800
861
|
|
|
801
|
-
|
|
802
|
-
|
|
862
|
+
def _load_data(self, data_columns: List[Tuple[str, str]]):
|
|
863
|
+
"""
|
|
864
|
+
Loads data from a list of data columns (field and sidecar name tuples).
|
|
803
865
|
|
|
804
|
-
|
|
866
|
+
Args:
|
|
867
|
+
data_columns: A list of tuples containing field name and sidecar name.
|
|
868
|
+
"""
|
|
805
869
|
tbs = []
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
fname = (
|
|
822
|
-
base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8")
|
|
823
|
-
if big_sidecar != "datum_id"
|
|
824
|
-
else big_sidecar
|
|
825
|
-
)
|
|
826
|
-
carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True) # type: ignore
|
|
870
|
+
|
|
871
|
+
sidecars_to_load = set([sidecar for _, sidecar in data_columns if sidecar != "datum_id"])
|
|
872
|
+
logger.info("Loading data")
|
|
873
|
+
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
874
|
+
# Use datum id as root table
|
|
875
|
+
tb = feather.read_table(
|
|
876
|
+
self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
|
|
877
|
+
)
|
|
878
|
+
for sidecar in sidecars_to_load:
|
|
879
|
+
path = self.projection.tile_destination
|
|
880
|
+
if sidecar == "":
|
|
881
|
+
path = path / Path(key).with_suffix(".feather")
|
|
882
|
+
else:
|
|
883
|
+
path = path / Path(key).with_suffix(f".{sidecar}.feather")
|
|
884
|
+
carfile = feather.read_table(path, memory_map=True)
|
|
827
885
|
for col in carfile.column_names:
|
|
828
|
-
|
|
886
|
+
if col in self.fields:
|
|
887
|
+
tb = tb.append_column(col, carfile[col])
|
|
829
888
|
tbs.append(tb)
|
|
830
|
-
self._tb = pa.concat_tables(tbs)
|
|
831
889
|
|
|
832
|
-
|
|
890
|
+
self._tb = pa.concat_tables(tbs)
|
|
833
891
|
|
|
834
|
-
def _download_data(self, fields=None):
|
|
892
|
+
def _download_data(self, fields: Optional[List[str]] = None) -> List[Tuple[str, str]]:
|
|
835
893
|
"""
|
|
836
|
-
Downloads the feather tree for
|
|
894
|
+
Downloads the feather tree for user uploaded data.
|
|
895
|
+
|
|
896
|
+
fields:
|
|
897
|
+
A list of fields to download. If None, downloads all fields.
|
|
898
|
+
|
|
899
|
+
Returns:
|
|
900
|
+
List of downloaded columns
|
|
837
901
|
"""
|
|
902
|
+
logger.info("Downloading data")
|
|
838
903
|
self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
|
|
839
|
-
root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
|
|
840
|
-
|
|
841
|
-
all_quads = list(self.projection._tiles_in_order(coords_only=True))
|
|
842
|
-
sidecars = fields
|
|
843
|
-
registered_sidecars = self.projection._registered_sidecars()
|
|
844
|
-
if sidecars is None:
|
|
845
|
-
sidecars = [
|
|
846
|
-
field
|
|
847
|
-
for field in self.dataset.dataset_fields
|
|
848
|
-
if field not in self._basic_data.column_names and field != "_embeddings"
|
|
849
|
-
]
|
|
850
|
-
else:
|
|
851
|
-
for field in sidecars:
|
|
852
|
-
assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
|
|
853
|
-
encoded_sidecars = [base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") for sidecar in sidecars]
|
|
854
|
-
if any(sidecar == "datum_id" for (field, sidecar) in registered_sidecars):
|
|
855
|
-
sidecars.append("datum_id")
|
|
856
|
-
encoded_sidecars.append("datum_id")
|
|
857
|
-
|
|
858
|
-
for quad in tqdm(all_quads):
|
|
859
|
-
for encoded_colname in encoded_sidecars:
|
|
860
|
-
quad_str = os.path.join(*[str(q) for q in quad])
|
|
861
|
-
filename = quad_str + "." + encoded_colname + ".feather"
|
|
862
|
-
path = self.projection.tile_destination / Path(filename)
|
|
863
904
|
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
905
|
+
# Download specified or all sidecar fields + always download datum_id
|
|
906
|
+
data_columns_to_load = [
|
|
907
|
+
(str(field), str(sidecar))
|
|
908
|
+
for field, sidecar in self.projection._registered_columns
|
|
909
|
+
if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
|
|
910
|
+
]
|
|
867
911
|
|
|
868
|
-
|
|
912
|
+
# TODO: less confusing progress bar
|
|
913
|
+
for sidecar in set([sidecar for _, sidecar in data_columns_to_load]):
|
|
914
|
+
self.projection._download_sidecar(sidecar)
|
|
915
|
+
return data_columns_to_load
|
|
869
916
|
|
|
870
917
|
@property
|
|
871
918
|
def df(self) -> pd.DataFrame:
|
|
@@ -873,7 +920,8 @@ class AtlasMapData:
|
|
|
873
920
|
A pandas DataFrame associating each datapoint on your map to their metadata.
|
|
874
921
|
Converting to pandas DataFrame may materialize a large amount of data into memory.
|
|
875
922
|
"""
|
|
876
|
-
|
|
923
|
+
logger.warning("Converting to pandas dataframe. This may materialize a large amount of data into memory.")
|
|
924
|
+
return self.tb.to_pandas()
|
|
877
925
|
|
|
878
926
|
@property
|
|
879
927
|
def tb(self) -> pa.Table:
|
|
@@ -882,4 +930,9 @@ class AtlasMapData:
|
|
|
882
930
|
This table is memmapped from the underlying files and is the most efficient way to
|
|
883
931
|
access metadata information.
|
|
884
932
|
"""
|
|
933
|
+
if isinstance(self._tb, pa.Table):
|
|
934
|
+
return self._tb
|
|
935
|
+
|
|
936
|
+
columns = self._download_data(fields=self.fields)
|
|
937
|
+
self._load_data(columns)
|
|
885
938
|
return self._tb
|
|
@@ -4,28 +4,21 @@ import concurrent.futures
|
|
|
4
4
|
import io
|
|
5
5
|
import json
|
|
6
6
|
import os
|
|
7
|
-
import pickle
|
|
8
7
|
import time
|
|
9
|
-
import uuid
|
|
10
|
-
from collections import defaultdict
|
|
11
8
|
from contextlib import contextmanager
|
|
12
|
-
from datetime import
|
|
9
|
+
from datetime import datetime
|
|
13
10
|
from pathlib import Path
|
|
14
|
-
from typing import
|
|
11
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
15
12
|
|
|
16
13
|
import numpy as np
|
|
17
|
-
import pandas as pd
|
|
18
14
|
import pyarrow as pa
|
|
19
15
|
import requests
|
|
20
16
|
from loguru import logger
|
|
21
17
|
from pandas import DataFrame
|
|
22
18
|
from pyarrow import compute as pc
|
|
23
19
|
from pyarrow import feather, ipc
|
|
24
|
-
from pydantic import BaseModel, Field
|
|
25
20
|
from tqdm import tqdm
|
|
26
21
|
|
|
27
|
-
import nomic
|
|
28
|
-
|
|
29
22
|
from .cli import refresh_bearer_token, validate_api_http_response
|
|
30
23
|
from .data_inference import (
|
|
31
24
|
NomicDuplicatesOptions,
|
|
@@ -36,7 +29,7 @@ from .data_inference import (
|
|
|
36
29
|
)
|
|
37
30
|
from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics
|
|
38
31
|
from .settings import *
|
|
39
|
-
from .utils import assert_valid_project_id,
|
|
32
|
+
from .utils import assert_valid_project_id, download_feather
|
|
40
33
|
|
|
41
34
|
|
|
42
35
|
class AtlasUser:
|
|
@@ -433,6 +426,8 @@ class AtlasProjection:
|
|
|
433
426
|
self._tile_data = None
|
|
434
427
|
self._data = None
|
|
435
428
|
self._schema = None
|
|
429
|
+
self._manifest_tb: Optional[pa.Table] = None
|
|
430
|
+
self._columns: List[Tuple[str, str]] = []
|
|
436
431
|
|
|
437
432
|
@property
|
|
438
433
|
def map_link(self):
|
|
@@ -590,147 +585,77 @@ class AtlasProjection:
|
|
|
590
585
|
self._schema = ipc.read_schema(io.BytesIO(content))
|
|
591
586
|
return self._schema
|
|
592
587
|
|
|
593
|
-
|
|
588
|
+
@property
|
|
589
|
+
def _registered_columns(self) -> List[Tuple[str, str]]:
|
|
594
590
|
"Returns [(field_name, sidecar_name), ...]"
|
|
595
|
-
|
|
591
|
+
if self._columns:
|
|
592
|
+
return self._columns
|
|
593
|
+
self._columns = []
|
|
596
594
|
for field in self.schema:
|
|
597
595
|
sidecar_name = json.loads(field.metadata.get(b"sidecar_name", b'""'))
|
|
598
|
-
if sidecar_name:
|
|
599
|
-
|
|
600
|
-
return
|
|
596
|
+
if sidecar_name is not None:
|
|
597
|
+
self._columns.append((field.name, sidecar_name))
|
|
598
|
+
return self._columns
|
|
601
599
|
|
|
602
|
-
|
|
600
|
+
@property
|
|
601
|
+
def _manifest(self) -> pa.Table:
|
|
603
602
|
"""
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
overwrite: If True then overwrite web tile files.
|
|
608
|
-
|
|
609
|
-
Returns:
|
|
610
|
-
An Arrow table containing information for all data points in the index.
|
|
603
|
+
Returns the tile manifest for the projection.
|
|
604
|
+
Tile manifest is in quadtree order. All quadtree operations should
|
|
605
|
+
depend on tile manifest to ensure consistency.
|
|
611
606
|
"""
|
|
612
|
-
if self._tile_data is not None:
|
|
613
|
-
return self._tile_data
|
|
614
|
-
self._download_large_feather(overwrite=overwrite)
|
|
615
|
-
tbs = []
|
|
616
|
-
root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True)
|
|
617
|
-
try:
|
|
618
|
-
sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()])
|
|
619
|
-
except KeyError:
|
|
620
|
-
sidecars = set([])
|
|
621
|
-
sidecars |= set(sidecar_name for (_, sidecar_name) in self._registered_sidecars())
|
|
622
|
-
for path in self._tiles_in_order():
|
|
623
|
-
tb = pa.feather.read_table(path, memory_map=True) # type: ignore
|
|
624
|
-
for sidecar_file in sidecars:
|
|
625
|
-
carfile = pa.feather.read_table( # type: ignore
|
|
626
|
-
path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True
|
|
627
|
-
)
|
|
628
|
-
for col in carfile.column_names:
|
|
629
|
-
tb = tb.append_column(col, carfile[col])
|
|
630
|
-
tbs.append(tb)
|
|
631
|
-
self._tile_data = pa.concat_tables(tbs)
|
|
607
|
+
if self._manifest_tb is not None:
|
|
608
|
+
return self._manifest_tb
|
|
632
609
|
|
|
633
|
-
|
|
610
|
+
manifest_path = self.tile_destination / "manifest.feather"
|
|
611
|
+
manifest_url = (
|
|
612
|
+
self.dataset.atlas_api_path
|
|
613
|
+
+ f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/manifest.feather"
|
|
614
|
+
)
|
|
634
615
|
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
def _tiles_in_order(self, *, coords_only: Literal[True]) -> Iterator[Tuple[int, int, int]]: ...
|
|
639
|
-
@overload
|
|
640
|
-
def _tiles_in_order(self, *, coords_only: bool) -> Iterator[Any]: ...
|
|
616
|
+
download_feather(manifest_url, manifest_path, headers=self.dataset.header, overwrite=False)
|
|
617
|
+
self._manifest_tb = feather.read_table(manifest_path, memory_map=False)
|
|
618
|
+
return self._manifest_tb
|
|
641
619
|
|
|
642
|
-
def
|
|
620
|
+
def _get_sidecar_from_field(self, field: str) -> str:
|
|
643
621
|
"""
|
|
644
|
-
Returns
|
|
645
|
-
A list of all tiles in the projection in a fixed order so that all
|
|
646
|
-
datasets are guaranteed to be aligned.
|
|
647
|
-
"""
|
|
648
|
-
|
|
649
|
-
def children(z, x, y):
|
|
650
|
-
# This is the definition of a quadtree.
|
|
651
|
-
return [
|
|
652
|
-
(z + 1, x * 2, y * 2),
|
|
653
|
-
(z + 1, x * 2 + 1, y * 2),
|
|
654
|
-
(z + 1, x * 2, y * 2 + 1),
|
|
655
|
-
(z + 1, x * 2 + 1, y * 2 + 1),
|
|
656
|
-
]
|
|
657
|
-
|
|
658
|
-
# start with the root
|
|
659
|
-
paths = [(0, 0, 0)]
|
|
660
|
-
# Pop off the front, extend the back (breadth first traversal)
|
|
661
|
-
while len(paths) > 0:
|
|
662
|
-
z, x, y = paths.pop(0)
|
|
663
|
-
path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix(".feather")
|
|
664
|
-
if path.exists():
|
|
665
|
-
if coords_only:
|
|
666
|
-
yield (z, x, y)
|
|
667
|
-
else:
|
|
668
|
-
yield path
|
|
669
|
-
paths.extend(children(z, x, y)) # pyright: ignore
|
|
622
|
+
Returns the sidecar name for a given field.
|
|
670
623
|
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
624
|
+
Args:
|
|
625
|
+
field: the name of the field
|
|
626
|
+
"""
|
|
627
|
+
for f, sidecar in self._registered_columns:
|
|
628
|
+
if field == f:
|
|
629
|
+
return sidecar
|
|
630
|
+
raise ValueError(f"Field {field} not found in registered columns.")
|
|
674
631
|
|
|
675
|
-
def
|
|
632
|
+
def _download_sidecar(self, sidecar_name, overwrite: bool = False) -> List[Path]:
|
|
676
633
|
"""
|
|
677
|
-
Downloads the
|
|
634
|
+
Downloads sidecar files from the quadtree
|
|
678
635
|
Args:
|
|
636
|
+
sidecar_name: the name of the sidecar file
|
|
679
637
|
overwrite: if True then overwrite existing feather files.
|
|
680
638
|
|
|
681
639
|
Returns:
|
|
682
|
-
|
|
683
|
-
"""
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
download_attempt += 1
|
|
702
|
-
if not path.exists() or overwrite:
|
|
703
|
-
data = requests.get(root + quad, headers=self.dataset.header)
|
|
704
|
-
readable = io.BytesIO(data.content)
|
|
705
|
-
readable.seek(0)
|
|
706
|
-
tb = feather.read_table(readable, memory_map=True)
|
|
707
|
-
path.parent.mkdir(parents=True, exist_ok=True)
|
|
708
|
-
feather.write_feather(tb, path)
|
|
709
|
-
try:
|
|
710
|
-
schema = ipc.open_file(path).schema
|
|
711
|
-
download_success = True
|
|
712
|
-
except pa.ArrowInvalid:
|
|
713
|
-
path.unlink(missing_ok=True)
|
|
714
|
-
|
|
715
|
-
if not download_success or schema is None:
|
|
716
|
-
raise Exception(f"Failed to download tiles. Aborting...")
|
|
717
|
-
|
|
718
|
-
if sidecars is None and b"sidecars" in schema.metadata:
|
|
719
|
-
# Grab just the filenames
|
|
720
|
-
sidecars = set([v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()])
|
|
721
|
-
elif sidecars is None:
|
|
722
|
-
sidecars = set()
|
|
723
|
-
if not "." in rawquad:
|
|
724
|
-
for sidecar in sidecars | registered_sidecars:
|
|
725
|
-
# The sidecar loses the feather suffix because it's supposed to be raw.
|
|
726
|
-
quads.append(quad.replace(".feather", f".{sidecar}"))
|
|
727
|
-
if not schema.metadata or b"children" not in schema.metadata:
|
|
728
|
-
# Sidecars don't have children.
|
|
729
|
-
continue
|
|
730
|
-
kids = schema.metadata.get(b"children")
|
|
731
|
-
children = json.loads(kids)
|
|
732
|
-
quads.extend(children)
|
|
733
|
-
return all_quads
|
|
640
|
+
List of downloaded feather files.
|
|
641
|
+
"""
|
|
642
|
+
downloaded_files = []
|
|
643
|
+
sidecar_suffix = "feather"
|
|
644
|
+
if sidecar_name != "":
|
|
645
|
+
sidecar_suffix = f"{sidecar_name}.feather"
|
|
646
|
+
for key in tqdm(self._manifest["key"].to_pylist()):
|
|
647
|
+
sidecar_path = self.tile_destination / f"{key}.{sidecar_suffix}"
|
|
648
|
+
sidecar_url = (
|
|
649
|
+
self.dataset.atlas_api_path
|
|
650
|
+
+ f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/{key}.{sidecar_suffix}"
|
|
651
|
+
)
|
|
652
|
+
download_feather(sidecar_url, sidecar_path, headers=self.dataset.header, overwrite=overwrite)
|
|
653
|
+
downloaded_files.append(sidecar_path)
|
|
654
|
+
return downloaded_files
|
|
655
|
+
|
|
656
|
+
@property
|
|
657
|
+
def tile_destination(self):
|
|
658
|
+
return Path("~/.nomic/cache", self.id).expanduser()
|
|
734
659
|
|
|
735
660
|
@property
|
|
736
661
|
def datum_id_field(self):
|
|
@@ -1160,7 +1085,7 @@ class AtlasDataset(AtlasClass):
|
|
|
1160
1085
|
|
|
1161
1086
|
build_template = {}
|
|
1162
1087
|
if self.modality == "embedding":
|
|
1163
|
-
if topic_model.
|
|
1088
|
+
if topic_model.topic_label_field is None:
|
|
1164
1089
|
logger.warning(
|
|
1165
1090
|
"You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
|
|
1166
1091
|
)
|
|
@@ -1188,7 +1113,7 @@ class AtlasDataset(AtlasClass):
|
|
|
1188
1113
|
"topic_model_hyperparameters": json.dumps(
|
|
1189
1114
|
{
|
|
1190
1115
|
"build_topic_model": topic_model.build_topic_model,
|
|
1191
|
-
"community_description_target_field": topic_model.
|
|
1116
|
+
"community_description_target_field": topic_model.topic_label_field, # TODO change key to topic_label_field post v0.0.85
|
|
1192
1117
|
"cluster_method": topic_model.cluster_method,
|
|
1193
1118
|
"enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
|
|
1194
1119
|
}
|
|
@@ -1253,7 +1178,7 @@ class AtlasDataset(AtlasClass):
|
|
|
1253
1178
|
"topic_model_hyperparameters": json.dumps(
|
|
1254
1179
|
{
|
|
1255
1180
|
"build_topic_model": topic_model.build_topic_model,
|
|
1256
|
-
"community_description_target_field": indexed_field,
|
|
1181
|
+
"community_description_target_field": indexed_field, # TODO change key to topic_label_field post v0.0.85
|
|
1257
1182
|
"cluster_method": topic_model.build_topic_model,
|
|
1258
1183
|
"enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
|
|
1259
1184
|
}
|
|
@@ -4,11 +4,12 @@ import random
|
|
|
4
4
|
import sys
|
|
5
5
|
from io import BytesIO
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Optional
|
|
7
|
+
from typing import Optional, Union
|
|
8
8
|
from uuid import UUID
|
|
9
9
|
|
|
10
10
|
import pyarrow as pa
|
|
11
11
|
import requests
|
|
12
|
+
from pyarrow import ipc
|
|
12
13
|
|
|
13
14
|
nouns = [
|
|
14
15
|
"newton",
|
|
@@ -241,10 +242,49 @@ def get_object_size_in_bytes(obj):
|
|
|
241
242
|
|
|
242
243
|
# Helpful function for downloading feather files
|
|
243
244
|
# Best for small feather files
|
|
244
|
-
def download_feather(
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
245
|
+
def download_feather(
|
|
246
|
+
url: Union[str, Path], path: Path, headers: Optional[dict] = None, num_attempts=1, overwrite=False
|
|
247
|
+
) -> pa.Schema:
|
|
248
|
+
"""
|
|
249
|
+
Download a feather file from a URL to a local path.
|
|
250
|
+
Returns the schema of feather file if successful.
|
|
251
|
+
|
|
252
|
+
Parameters:
|
|
253
|
+
url (str): URL to download feather file from.
|
|
254
|
+
path (Path): Local path to save feather file to.
|
|
255
|
+
headers (dict): Optional headers to include in request.
|
|
256
|
+
num_attempts (int): Number of download attempts before raising an error.
|
|
257
|
+
overwrite (bool): Whether to overwrite existing file.
|
|
258
|
+
Returns:
|
|
259
|
+
Feather schema.
|
|
260
|
+
"""
|
|
261
|
+
assert num_attempts > 0, "Num attempts must be greater than 0"
|
|
262
|
+
download_attempt = 0
|
|
263
|
+
download_success = False
|
|
264
|
+
schema = None
|
|
265
|
+
while download_attempt < num_attempts and not download_success:
|
|
266
|
+
download_attempt += 1
|
|
267
|
+
if not path.exists() or overwrite:
|
|
268
|
+
# Attempt download
|
|
269
|
+
try:
|
|
270
|
+
data = requests.get(str(url), headers=headers)
|
|
271
|
+
readable = BytesIO(data.content)
|
|
272
|
+
readable.seek(0)
|
|
273
|
+
tb = pa.feather.read_table(readable, memory_map=False) # type: ignore
|
|
274
|
+
schema = tb.schema
|
|
275
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
276
|
+
pa.feather.write_feather(tb, path) # type: ignore
|
|
277
|
+
download_success = True
|
|
278
|
+
except pa.ArrowInvalid:
|
|
279
|
+
# failed try again
|
|
280
|
+
path.unlink(missing_ok=True)
|
|
281
|
+
else:
|
|
282
|
+
# Load existing file
|
|
283
|
+
try:
|
|
284
|
+
schema = ipc.open_file(path).schema
|
|
285
|
+
download_success = True
|
|
286
|
+
except pa.ArrowInvalid:
|
|
287
|
+
path.unlink(missing_ok=True)
|
|
288
|
+
if not download_success or schema is None:
|
|
289
|
+
raise ValueError(f"Failed to download feather file from {url} after {num_attempts} attempts.")
|
|
290
|
+
return schema
|
|
@@ -8,7 +8,7 @@ description = "The official Nomic python client."
|
|
|
8
8
|
|
|
9
9
|
setup(
|
|
10
10
|
name="nomic",
|
|
11
|
-
version="3.0.31",
|
|
11
|
+
version="3.0.33",
|
|
12
12
|
url="https://github.com/nomic-ai/nomic",
|
|
13
13
|
description=description,
|
|
14
14
|
long_description=description,
|
|
@@ -43,7 +43,7 @@ setup(
|
|
|
43
43
|
],
|
|
44
44
|
"dev": [
|
|
45
45
|
"nomic[all]",
|
|
46
|
-
"black",
|
|
46
|
+
"black==24.3.0",
|
|
47
47
|
"coverage",
|
|
48
48
|
"pylint",
|
|
49
49
|
"pytest",
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|