nomic-3.5.0.tar.gz → nomic-3.5.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nomic-3.5.0 → nomic-3.5.2}/PKG-INFO +1 -1
- {nomic-3.5.0 → nomic-3.5.2}/nomic/atlas.py +3 -39
- {nomic-3.5.0 → nomic-3.5.2}/nomic/data_operations.py +337 -87
- {nomic-3.5.0 → nomic-3.5.2}/nomic/dataset.py +138 -156
- {nomic-3.5.0 → nomic-3.5.2}/nomic.egg-info/PKG-INFO +1 -1
- {nomic-3.5.0 → nomic-3.5.2}/setup.py +1 -1
- {nomic-3.5.0 → nomic-3.5.2}/README.md +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/__init__.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/aws/__init__.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/aws/sagemaker.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/cli.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/data_inference.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/embed.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/pl_callbacks/__init__.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/pl_callbacks/pl_callback.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/settings.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic/utils.py +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic.egg-info/SOURCES.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic.egg-info/dependency_links.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic.egg-info/entry_points.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic.egg-info/requires.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/nomic.egg-info/top_level.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/pyproject.toml +0 -0
- {nomic-3.5.0 → nomic-3.5.2}/setup.cfg +0 -0
{nomic-3.5.0 → nomic-3.5.2}/nomic/atlas.py
@@ -3,7 +3,6 @@ This class allows for programmatic interactions with Atlas - Nomic's neural data
 or in a Jupyter Notebook to organize and interact with your unstructured data.
 """
 
-import uuid
 from typing import Dict, Iterable, List, Optional, Union
 
 import numpy as np
@@ -42,7 +41,7 @@ def map_data(
 embeddings: An [N,d] numpy array containing the N embeddings to add.
 identifier: A name for your dataset that is used to generate the dataset identifier. A unique name will be chosen if not supplied.
 description: The description of your dataset
-id_field: Specify
+id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
 is_public: Should the dataset be accessible outside your Nomic Atlas organization.
 indexed_field: The text field from the dataset that will be used to create embeddings, which determines the layout of the data map in Atlas. Required for text data but won't have an impact if uploading embeddings or image blobs.
 projection: Options for configuring the 2D projection algorithm.
@@ -86,9 +85,6 @@ def map_data(
 # default to vision v1.5
 embedding_model = NomicEmbedOptions(model="nomic-embed-vision-v1.5")
 
-if id_field is None:
-    id_field = ATLAS_DEFAULT_ID_FIELD
-
 project_name = get_random_name()
 
 dataset_name = project_name
@@ -100,38 +96,6 @@ def map_data(
 if description:
     description = description
 
-# no metadata was specified
-added_id_field = False
-
-if data is None:
-    added_id_field = True
-    if embeddings is not None:
-        data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(embeddings))]
-    elif blobs is not None:
-        data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(blobs))]
-    else:
-        raise ValueError("You must specify either data, embeddings, or blobs")
-
-if id_field == ATLAS_DEFAULT_ID_FIELD and data is not None:
-    if isinstance(data, list) and id_field not in data[0]:
-        added_id_field = True
-        for i in range(len(data)):
-            # do not modify object the user passed in - also ensures IDs are unique if two input datums are the same *object*
-            data[i] = data[i].copy()
-            data[i][id_field] = b64int(i)
-    elif isinstance(data, DataFrame) and id_field not in data.columns:
-        data[id_field] = [b64int(i) for i in range(data.shape[0])]
-        added_id_field = True
-    elif isinstance(data, pa.Table) and not id_field in data.column_names:  # type: ignore
-        ids = pa.array([b64int(i) for i in range(len(data))])
-        data = data.append_column(id_field, ids)  # type: ignore
-        added_id_field = True
-    elif id_field not in data[0]:
-        raise ValueError("map_data data must be a list of dicts, a pandas dataframe, or a pyarrow table")
-
-if added_id_field:
-    logger.warning("An ID field was not specified in your data so one was generated for you in insertion order.")
-
 dataset = AtlasDataset(
     identifier=dataset_name, description=description, unique_id_field=id_field, is_public=is_public
 )
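A minimal usage sketch (not part of the diff) of what the change above means for callers: the client no longer injects a default ID column, so `id_field` stays optional, and when one is supplied its values must be unique and at most 36 characters long (enforced during upload validation in dataset.py below). The dataset identifier and field names here are hypothetical.

    from nomic import atlas

    documents = [{"doc_id": f"doc-{i}", "text": f"example document {i}"} for i in range(100)]

    dataset = atlas.map_data(
        data=documents,
        indexed_field="text",   # text field used to build embeddings
        id_field="doc_id",      # optional; omit to let Atlas manage row identity
        identifier="example-mapped-docs",
        is_public=False,
    )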
@@ -202,7 +166,7 @@ def map_embeddings(
 Args:
     embeddings: An [N,d] numpy array containing the batch of N embeddings to add.
     data: An [N,] element list of dictionaries containing metadata for each embedding.
-    id_field: Specify
+    id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
     name: A name for your dataset. Specify in the format `organization/project` to create in a specific organization.
     description: A description for your map.
     is_public: Should this embedding map be public? Private maps can only be accessed by members of your organization.
@@ -250,7 +214,7 @@ def map_text(
 Args:
     data: An [N,] element iterable of dictionaries containing metadata for each embedding.
     indexed_field: The name the data field containing the text your want to map.
-    id_field: Specify
+    id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
     name: A name for your dataset. Specify in the format `organization/project` to create in a specific organization.
     description: A description for your map.
     build_topic_model: Builds a hierarchical topic model over your data to discover patterns.
{nomic-3.5.0 → nomic-3.5.2}/nomic/data_operations.py
@@ -25,7 +25,6 @@ class AtlasMapDuplicates:
 
 def __init__(self, projection: "AtlasProjection"):  # type: ignore
     self.projection = projection
-    self.id_field = self.projection.dataset.id_field
 
     duplicate_columns = [
         (field, sidecar)
@@ -40,7 +39,13 @@ class AtlasMapDuplicates:
 
     self._duplicate_column = duplicate_columns[0]
     self._cluster_column = cluster_columns[0]
+    self.duplicate_field = None
+    self.cluster_field = None
     self._tb = None
+    self._has_unique_id_field = (
+        "unique_id_field" in self.projection.dataset.meta
+        and self.projection.dataset.meta["unique_id_field"] is not None
+    )
 
 def _load_duplicates(self):
     """
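The `_has_unique_id_field` flag added above is just a metadata check; a standalone sketch of the same test:

    def has_unique_id_field(meta: dict) -> bool:
        # A dataset only carries real datum IDs when its metadata names a unique_id_field.
        return "unique_id_field" in meta and meta["unique_id_field"] is not None

    assert has_unique_id_field({"unique_id_field": "doc_id"}) is True
    assert has_unique_id_field({"unique_id_field": None}) is False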
@@ -51,30 +56,80 @@ class AtlasMapDuplicates:
     self.duplicate_field = self._duplicate_column[0].lstrip("_")
     self.cluster_field = self._cluster_column[0].lstrip("_")
     logger.info("Loading duplicates")
+    id_field_name = "_position_index"
+    if self._has_unique_id_field:
+        id_field_name = self.projection.dataset.meta["unique_id_field"]
+
     for key in tqdm(self.projection._manifest["key"].to_pylist()):
-        # Use datum id as root table
-
-
-
-
+        # Use datum id as root table if available, otherwise create synthetic IDs
+        if self._has_unique_id_field:
+            try:
+                datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+                tb = feather.read_table(datum_id_path, memory_map=True)
+            except (FileNotFoundError, pa.ArrowInvalid):
+                # Create a synthetic ID table
+                tb = self._create_synthetic_id_table(key, duplicate_sidecar)
+        else:
+            # Create synthetic IDs when no unique_id_field is available
+            tb = self._create_synthetic_id_table(key, duplicate_sidecar)
 
+        path = self.projection.tile_destination
         if duplicate_sidecar == "":
            path = path / Path(key).with_suffix(".feather")
         else:
            path = path / Path(key).with_suffix(f".{duplicate_sidecar}.feather")
 
-
-
-
-
-
+        try:
+            duplicate_tb = feather.read_table(path, memory_map=True)
+            for field in (self._duplicate_column[0], self._cluster_column[0]):
+                tb = tb.append_column(field, duplicate_tb[field])
+            tbs.append(tb)
+        except (FileNotFoundError, pa.ArrowInvalid) as e:
+            logger.warning(f"Error loading duplicate data for key {key}: {e}")
+            continue
+
+    if not tbs:
+        raise ValueError("No duplicate data could be loaded. Duplicate data files may be missing or corrupt.")
+
+    self._tb = pa.concat_tables(tbs).rename_columns([id_field_name, self.duplicate_field, self.cluster_field])
+
+def _create_synthetic_id_table(self, key, sidecar):
+    """
+    Create a synthetic table with position indices when datum_id file isn't available
+    or when unique_id_field isn't specified.
+    """
+    # Try to determine the size of the table by loading the duplicate sidecar
+    path = self.projection.tile_destination
+    if sidecar == "":
+        path = path / Path(key).with_suffix(".feather")
+    else:
+        path = path / Path(key).with_suffix(f".{sidecar}.feather")
+
+    try:
+        sidecar_tb = feather.read_table(path, memory_map=True)
+        size = len(sidecar_tb)
+        # Create a table with position indices as IDs
+        position_indices = [f"pos_{i}" for i in range(size)]
+        return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
+    except Exception as e:
+        logger.error(f"Failed to create synthetic IDs for {key}: {e}")
+        # Return an empty table as fallback
+        return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
 
 def _download_duplicates(self):
     """
     Downloads the feather tree for duplicates.
     """
     logger.info("Downloading duplicates")
-
+
+    # Only download datum_id if we have a unique ID field
+    if self._has_unique_id_field:
+        try:
+            self.projection._download_sidecar("datum_id", overwrite=False)
+        except ValueError as e:
+            logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
+            self._has_unique_id_field = False
+
     assert self._cluster_column[1] == self._duplicate_column[1], "Cluster and duplicate should be in same sidecar"
     self.projection._download_sidecar(self._duplicate_column[1], overwrite=False)
 
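The `_create_synthetic_id_table` helper introduced above derives row identities purely from positions in an existing sidecar file; a standalone sketch of the same idea (path handling simplified):

    import pyarrow as pa
    from pyarrow import feather

    def synthetic_id_table(sidecar_path: str) -> pa.Table:
        # Position-based IDs ("pos_0", "pos_1", ...) sized to the sidecar table.
        sidecar_tb = feather.read_table(sidecar_path, memory_map=True)
        position_indices = [f"pos_{i}" for i in range(len(sidecar_tb))]
        return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])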
@@ -100,17 +155,24 @@ class AtlasMapDuplicates:
 
 def deletion_candidates(self) -> List[str]:
     """
-
     Returns:
         The ids for all data points which are semantic duplicates and are candidates for being deleted from the dataset. If you remove these data points from your dataset, your dataset will be semantically deduplicated.
     """
-
+    id_field_name = "_position_index"
+    if self._has_unique_id_field:
+        id_field_name = self.projection.dataset.meta["unique_id_field"]
+
+    dupes = self.tb[id_field_name].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate"))  # type: ignore
     return dupes.to_pylist()
 
 def __repr__(self) -> str:
+    id_field_name = "_position_index"
+    if self._has_unique_id_field:
+        id_field_name = self.projection.dataset.meta["unique_id_field"]
+
     repr = f"===Atlas Duplicates for ({self.projection})\n"
     duplicate_count = len(
-        self.tb[
+        self.tb[id_field_name].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate"))  # type: ignore
     )
     cluster_count = len(self.tb[self.cluster_field].value_counts())
     repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n"
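The filter used by `deletion_candidates()` above can be exercised on its own; a sketch with a small in-memory table (column names chosen for illustration):

    import pyarrow as pa
    import pyarrow.compute as pc

    tb = pa.table({
        "_position_index": ["pos_0", "pos_1", "pos_2"],
        "duplicate_class": ["singleton", "deletion candidate", "retention candidate"],
    })
    mask = pc.equal(tb["duplicate_class"], "deletion candidate")
    dupes = tb["_position_index"].filter(mask).to_pylist()
    # dupes == ["pos_1"]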
@@ -125,7 +187,6 @@ class AtlasMapTopics:
 def __init__(self, projection: "AtlasProjection"):  # type: ignore
     self.projection = projection
     self.dataset = projection.dataset
-    self.id_field = self.projection.dataset.id_field
     self._metadata = None
     self._hierarchy = None
     self._topic_columns = [
@@ -134,6 +195,9 @@ class AtlasMapTopics:
     assert len(self._topic_columns) > 0, "Topic modeling has not yet been run on this map."
     self.depth = len(self._topic_columns)
     self._tb = None
+    self._has_unique_id_field = (
+        "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
+    )
 
 def _load_topics(self):
     """
@@ -149,45 +213,96 @@ class AtlasMapTopics:
|
|
|
149
213
|
# Should just be one sidecar
|
|
150
214
|
topic_sidecar = set([sidecar for _, sidecar in self._topic_columns]).pop()
|
|
151
215
|
logger.info("Loading topics")
|
|
216
|
+
id_field_name = "_position_index"
|
|
217
|
+
if self._has_unique_id_field:
|
|
218
|
+
id_field_name = self.dataset.meta["unique_id_field"]
|
|
219
|
+
|
|
152
220
|
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
153
|
-
# Use datum id as root table
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
221
|
+
# Use datum id as root table if available, otherwise create a synthetic index
|
|
222
|
+
if self._has_unique_id_field:
|
|
223
|
+
try:
|
|
224
|
+
datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
|
|
225
|
+
tb = feather.read_table(datum_id_path, memory_map=True)
|
|
226
|
+
except (FileNotFoundError, pa.ArrowInvalid):
|
|
227
|
+
# If datum_id file doesn't exist, create a table with a position index
|
|
228
|
+
logger.warning(f"Datum ID file not found for key {key}. Creating synthetic IDs.")
|
|
229
|
+
tb = self._create_synthetic_id_table(key, topic_sidecar)
|
|
230
|
+
else:
|
|
231
|
+
# Create a table with a position index when no unique_id_field is available
|
|
232
|
+
tb = self._create_synthetic_id_table(key, topic_sidecar)
|
|
233
|
+
|
|
157
234
|
path = self.projection.tile_destination
|
|
158
235
|
if topic_sidecar == "":
|
|
159
236
|
path = path / Path(key).with_suffix(".feather")
|
|
160
237
|
else:
|
|
161
238
|
path = path / Path(key).with_suffix(f".{topic_sidecar}.feather")
|
|
162
239
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
topic_ids_to_label
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
240
|
+
try:
|
|
241
|
+
topic_tb = feather.read_table(path, memory_map=True)
|
|
242
|
+
# Do this in depth order
|
|
243
|
+
for d in range(1, self.depth + 1):
|
|
244
|
+
column = f"_topic_depth_{d}"
|
|
245
|
+
if integer_topics:
|
|
246
|
+
column = f"_topic_depth_{d}_int"
|
|
247
|
+
topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
|
|
248
|
+
assert label_df is not None
|
|
249
|
+
topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
|
|
250
|
+
topic_ids_to_label, on="topic_id", how="right"
|
|
251
|
+
)
|
|
252
|
+
new_column = f"_topic_depth_{d}"
|
|
253
|
+
tb = tb.append_column(
|
|
254
|
+
new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
|
|
255
|
+
)
|
|
256
|
+
else:
|
|
257
|
+
tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
|
|
258
|
+
tbs.append(tb)
|
|
259
|
+
except (FileNotFoundError, pa.ArrowInvalid) as e:
|
|
260
|
+
logger.warning(f"Error loading topic data for key {key}: {e}")
|
|
261
|
+
continue
|
|
181
262
|
|
|
182
|
-
|
|
263
|
+
if not tbs:
|
|
264
|
+
raise ValueError("No topic data could be loaded. Topic data files may be missing or corrupt.")
|
|
265
|
+
|
|
266
|
+
renamed_columns = [id_field_name] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
|
|
183
267
|
self._tb = pa.concat_tables(tbs).rename_columns(renamed_columns)
|
|
184
268
|
|
|
269
|
+
def _create_synthetic_id_table(self, key, sidecar):
|
|
270
|
+
"""
|
|
271
|
+
Create a synthetic table with position indices when datum_id file isn't available
|
|
272
|
+
or when unique_id_field isn't specified.
|
|
273
|
+
"""
|
|
274
|
+
# Try to determine the size of the table by loading the topic sidecar
|
|
275
|
+
path = self.projection.tile_destination
|
|
276
|
+
if sidecar == "":
|
|
277
|
+
path = path / Path(key).with_suffix(".feather")
|
|
278
|
+
else:
|
|
279
|
+
path = path / Path(key).with_suffix(f".{sidecar}.feather")
|
|
280
|
+
|
|
281
|
+
try:
|
|
282
|
+
topic_tb = feather.read_table(path, memory_map=True)
|
|
283
|
+
size = len(topic_tb)
|
|
284
|
+
# Create a table with position indices as IDs
|
|
285
|
+
position_indices = [f"pos_{i}" for i in range(size)]
|
|
286
|
+
return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
|
|
287
|
+
except Exception as e:
|
|
288
|
+
logger.error(f"Failed to create synthetic IDs for {key}: {e}")
|
|
289
|
+
# Return an empty table as fallback
|
|
290
|
+
return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
|
|
291
|
+
|
|
185
292
|
def _download_topics(self):
|
|
186
293
|
"""
|
|
187
294
|
Downloads the feather tree for topics.
|
|
188
295
|
"""
|
|
189
296
|
logger.info("Downloading topics")
|
|
190
|
-
|
|
297
|
+
|
|
298
|
+
# Only download datum_id if we have a unique ID field
|
|
299
|
+
if self._has_unique_id_field:
|
|
300
|
+
try:
|
|
301
|
+
self.projection._download_sidecar("datum_id", overwrite=False)
|
|
302
|
+
except ValueError as e:
|
|
303
|
+
logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
|
|
304
|
+
self._has_unique_id_field = False
|
|
305
|
+
|
|
191
306
|
topic_sidecars = set([sidecar for _, sidecar in self._topic_columns])
|
|
192
307
|
assert len(topic_sidecars) == 1, "Multiple topic sidecars found."
|
|
193
308
|
self.projection._download_sidecar(topic_sidecars.pop(), overwrite=False)
|
|
@@ -286,7 +401,10 @@ class AtlasMapTopics:
     raise ValueError("Topic depth out of range.")
 
 # Unique datum id column to aggregate
-datum_id_col =
+datum_id_col = "_position_index"
+if "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None:
+    datum_id_col = self.dataset.meta["unique_id_field"]
+
 df = self.df
 
 topic_datum_dict = df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict()
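The aggregation above groups the map dataframe by topic label and collects the datum IDs under each topic; a compact pandas sketch of the same pattern:

    import pandas as pd

    df = pd.DataFrame({
        "topic_depth_1": ["a", "a", "b"],
        "_position_index": ["pos_0", "pos_1", "pos_2"],
    })
    topic_datum_dict = df.groupby("topic_depth_1")["_position_index"].apply(set).to_dict()
    # {'a': {'pos_0', 'pos_1'}, 'b': {'pos_2'}}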
@@ -333,8 +451,12 @@ class AtlasMapTopics:
     A list of `{topic, count}` dictionaries, sorted from largest count to smallest count.
 """
 data = AtlasMapData(self.projection, fields=[time_field])
-
-
+id_field_name = "_position_index"
+if "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None:
+    id_field_name = self.dataset.meta["unique_id_field"]
+
+time_data = data.tb.select([id_field_name, time_field])
+merged_tb = self.tb.join(time_data, id_field_name, join_type="inner").combine_chunks()
 
 del time_data  # free up memory
 
@@ -343,12 +465,12 @@ class AtlasMapTopics:
     topic_densities = {}
     for depth in range(1, self.depth + 1):
         topic_column = f"topic_depth_{depth}"
-        topic_counts = merged_tb.group_by(topic_column).aggregate([(
+        topic_counts = merged_tb.group_by(topic_column).aggregate([(id_field_name, "count")]).to_pandas()
         for _, row in topic_counts.iterrows():
             topic = row[topic_column]
             if topic not in topic_densities:
                 topic_densities[topic] = 0
-            topic_densities[topic] += row[
+            topic_densities[topic] += row[id_field_name + "_count"]
     return topic_densities
 
 def vector_search_topics(self, queries: np.ndarray, k: int = 32, depth: int = 3) -> Dict:
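The counting step above relies on pyarrow's `Table.group_by(...).aggregate([(column, "count")])`, which emits a `<column>_count` output column; a small self-contained sketch:

    import pyarrow as pa

    merged_tb = pa.table({
        "topic_depth_1": ["a", "a", "b"],
        "_position_index": ["pos_0", "pos_1", "pos_2"],
    })
    topic_counts = merged_tb.group_by("topic_depth_1").aggregate([("_position_index", "count")]).to_pandas()
    # columns: topic_depth_1, _position_index_count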
@@ -448,10 +570,12 @@ class AtlasMapEmbeddings:
 
 def __init__(self, projection: "AtlasProjection"):  # type: ignore
     self.projection = projection
-    self.id_field = self.projection.dataset.id_field
     self.dataset = projection.dataset
     self._tb: pa.Table = None
     self._latent = None
+    self._has_unique_id_field = (
+        "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
+    )
 
 @property
 def df(self):
@@ -481,23 +605,81 @@ class AtlasMapEmbeddings:
|
|
|
481
605
|
tbs = []
|
|
482
606
|
coord_sidecar = self.projection._get_sidecar_from_field("x")
|
|
483
607
|
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
484
|
-
# Use datum id as root table
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
608
|
+
# Use datum id as root table if available, otherwise create synthetic IDs
|
|
609
|
+
if self._has_unique_id_field:
|
|
610
|
+
try:
|
|
611
|
+
datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
|
|
612
|
+
tb = feather.read_table(datum_id_path, memory_map=True)
|
|
613
|
+
except (FileNotFoundError, pa.ArrowInvalid):
|
|
614
|
+
# Create a synthetic ID table
|
|
615
|
+
tb = self._create_synthetic_id_table(key, coord_sidecar)
|
|
616
|
+
else:
|
|
617
|
+
# Create synthetic IDs when no unique_id_field is available
|
|
618
|
+
tb = self._create_synthetic_id_table(key, coord_sidecar)
|
|
619
|
+
|
|
488
620
|
path = self.projection.tile_destination
|
|
489
621
|
if coord_sidecar == "":
|
|
490
622
|
path = path / Path(key).with_suffix(".feather")
|
|
491
623
|
else:
|
|
492
624
|
path = path / Path(key).with_suffix(f".{coord_sidecar}.feather")
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
625
|
+
|
|
626
|
+
try:
|
|
627
|
+
carfile = feather.read_table(path, memory_map=True)
|
|
628
|
+
for col in carfile.column_names:
|
|
629
|
+
if col in ["x", "y"]:
|
|
630
|
+
tb = tb.append_column(col, carfile[col])
|
|
631
|
+
tbs.append(tb)
|
|
632
|
+
except (FileNotFoundError, pa.ArrowInvalid) as e:
|
|
633
|
+
logger.warning(f"Error loading embedding data for key {key}: {e}")
|
|
634
|
+
continue
|
|
635
|
+
|
|
636
|
+
if not tbs:
|
|
637
|
+
raise ValueError("No embedding data could be loaded. Embedding data files may be missing or corrupt.")
|
|
638
|
+
|
|
498
639
|
self._tb = pa.concat_tables(tbs)
|
|
499
640
|
return self._tb
|
|
500
641
|
|
|
642
|
+
def _create_synthetic_id_table(self, key, sidecar):
|
|
643
|
+
"""
|
|
644
|
+
Create a synthetic table with position indices when datum_id file isn't available
|
|
645
|
+
or when unique_id_field isn't specified.
|
|
646
|
+
"""
|
|
647
|
+
# Try to determine the size of the table by loading the coordinate sidecar
|
|
648
|
+
path = self.projection.tile_destination
|
|
649
|
+
if sidecar == "":
|
|
650
|
+
path = path / Path(key).with_suffix(".feather")
|
|
651
|
+
else:
|
|
652
|
+
path = path / Path(key).with_suffix(f".{sidecar}.feather")
|
|
653
|
+
|
|
654
|
+
try:
|
|
655
|
+
coord_tb = feather.read_table(path, memory_map=True)
|
|
656
|
+
size = len(coord_tb)
|
|
657
|
+
# Create a table with position indices as IDs
|
|
658
|
+
position_indices = [f"pos_{i}" for i in range(size)]
|
|
659
|
+
return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
|
|
660
|
+
except Exception as e:
|
|
661
|
+
logger.error(f"Failed to create synthetic IDs for {key}: {e}")
|
|
662
|
+
# Return an empty table as fallback
|
|
663
|
+
return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
|
|
664
|
+
|
|
665
|
+
def _download_projected(self) -> List[Path]:
|
|
666
|
+
"""
|
|
667
|
+
Downloads the feather tree for projection coordinates.
|
|
668
|
+
"""
|
|
669
|
+
logger.info("Downloading projected embeddings")
|
|
670
|
+
# Note that y coord should be in same sidecar
|
|
671
|
+
coord_sidecar = self.projection._get_sidecar_from_field("x")
|
|
672
|
+
|
|
673
|
+
# Only download datum_id if we have a unique ID field
|
|
674
|
+
if self._has_unique_id_field:
|
|
675
|
+
try:
|
|
676
|
+
self.projection._download_sidecar("datum_id", overwrite=False)
|
|
677
|
+
except ValueError as e:
|
|
678
|
+
logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
|
|
679
|
+
self._has_unique_id_field = False
|
|
680
|
+
|
|
681
|
+
return self.projection._download_sidecar(coord_sidecar, overwrite=False)
|
|
682
|
+
|
|
501
683
|
@property
|
|
502
684
|
def projected(self) -> pd.DataFrame:
|
|
503
685
|
"""
|
|
@@ -531,16 +713,6 @@ class AtlasMapEmbeddings:
|
|
|
531
713
|
all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims)) # type: ignore
|
|
532
714
|
return np.vstack(all_embeddings)
|
|
533
715
|
|
|
534
|
-
def _download_projected(self) -> List[Path]:
|
|
535
|
-
"""
|
|
536
|
-
Downloads the feather tree for projection coordinates.
|
|
537
|
-
"""
|
|
538
|
-
logger.info("Downloading projected embeddings")
|
|
539
|
-
# Note that y coord should be in same sidecar
|
|
540
|
-
coord_sidecar = self.projection._get_sidecar_from_field("x")
|
|
541
|
-
self.projection._download_sidecar("datum_id", overwrite=False)
|
|
542
|
-
return self.projection._download_sidecar(coord_sidecar, overwrite=False)
|
|
543
|
-
|
|
544
716
|
def _download_latent(self) -> List[Path]:
|
|
545
717
|
"""
|
|
546
718
|
Downloads the feather tree for embeddings.
|
|
@@ -670,12 +842,17 @@ class AtlasMapTags:
|
|
|
670
842
|
def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False): # type: ignore
|
|
671
843
|
self.projection = projection
|
|
672
844
|
self.dataset = projection.dataset
|
|
673
|
-
self.id_field = self.projection.dataset.id_field
|
|
674
845
|
# Pre-fetch datum ids first upon initialization
|
|
675
|
-
|
|
676
|
-
self.
|
|
677
|
-
|
|
678
|
-
|
|
846
|
+
self._has_unique_id_field = (
|
|
847
|
+
"unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
|
|
848
|
+
)
|
|
849
|
+
|
|
850
|
+
if self._has_unique_id_field:
|
|
851
|
+
try:
|
|
852
|
+
self.projection._download_sidecar("datum_id")
|
|
853
|
+
except Exception as e:
|
|
854
|
+
logger.warning(f"Failed to fetch datum ids: {e}. Will use synthetic IDs.")
|
|
855
|
+
self._has_unique_id_field = False
|
|
679
856
|
self.auto_cleanup = auto_cleanup
|
|
680
857
|
|
|
681
858
|
@property
|
|
@@ -748,20 +925,38 @@ class AtlasMapTags:
|
|
|
748
925
|
"""
|
|
749
926
|
tag_paths = self._download_tag(tag_name, overwrite=overwrite)
|
|
750
927
|
datum_ids = []
|
|
928
|
+
id_field_name = "_position_index"
|
|
929
|
+
if self._has_unique_id_field:
|
|
930
|
+
id_field_name = self.dataset.meta["unique_id_field"]
|
|
931
|
+
|
|
751
932
|
for path in tag_paths:
|
|
752
933
|
tb = feather.read_table(path)
|
|
753
934
|
last_coord = path.name.split(".")[0]
|
|
754
|
-
|
|
755
|
-
|
|
935
|
+
|
|
936
|
+
# Get ID information - either from datum_id file or create synthetic IDs
|
|
937
|
+
if self._has_unique_id_field:
|
|
938
|
+
try:
|
|
939
|
+
tile_path = path.with_name(last_coord + ".datum_id.feather")
|
|
940
|
+
tile_tb = feather.read_table(tile_path).select([id_field_name])
|
|
941
|
+
except (FileNotFoundError, pa.ArrowInvalid):
|
|
942
|
+
# Create synthetic IDs if datum_id file not found
|
|
943
|
+
size = len(tb)
|
|
944
|
+
position_indices = [f"pos_{i}" for i in range(size)]
|
|
945
|
+
tile_tb = pa.Table.from_arrays([pa.array(position_indices)], names=[id_field_name])
|
|
946
|
+
else:
|
|
947
|
+
# Create synthetic IDs when no unique_id_field is available
|
|
948
|
+
size = len(tb)
|
|
949
|
+
position_indices = [f"pos_{i}" for i in range(size)]
|
|
950
|
+
tile_tb = pa.Table.from_arrays([pa.array(position_indices)], names=[id_field_name])
|
|
756
951
|
|
|
757
952
|
if "all_set" in tb.column_names:
|
|
758
953
|
if tb["all_set"][0].as_py() == True:
|
|
759
|
-
datum_ids.extend(tile_tb[
|
|
954
|
+
datum_ids.extend(tile_tb[id_field_name].to_pylist())
|
|
760
955
|
else:
|
|
761
956
|
# filter on rows
|
|
762
957
|
try:
|
|
763
|
-
tb = tb.append_column(
|
|
764
|
-
datum_ids.extend(tb.filter(pc.field("bitmask") == True)[
|
|
958
|
+
tb = tb.append_column(id_field_name, tile_tb[id_field_name])
|
|
959
|
+
datum_ids.extend(tb.filter(pc.field("bitmask") == True)[id_field_name].to_pylist())
|
|
765
960
|
except Exception as e:
|
|
766
961
|
raise Exception(f"Failed to fetch datums in tag. {e}")
|
|
767
962
|
return datum_ids
|
|
@@ -849,7 +1044,6 @@ class AtlasMapData:
 def __init__(self, projection: "AtlasProjection", fields=None):  # type: ignore
     self.projection = projection
     self.dataset = projection.dataset
-    self.id_field = self.projection.dataset.id_field
     if fields is None:
         # TODO: fall back on something more reliable here
         self.fields = self.dataset.dataset_fields
@@ -858,6 +1052,9 @@ class AtlasMapData:
         assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
     self.fields = fields
     self._tb = None
+    self._has_unique_id_field = (
+        "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
+    )
 
 def _load_data(self, data_columns: List[Tuple[str, str]]):
     """
@@ -871,24 +1068,67 @@ class AtlasMapData:
|
|
|
871
1068
|
sidecars_to_load = set([sidecar for _, sidecar in data_columns if sidecar != "datum_id"])
|
|
872
1069
|
logger.info("Loading data")
|
|
873
1070
|
for key in tqdm(self.projection._manifest["key"].to_pylist()):
|
|
874
|
-
# Use datum id as root table
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
1071
|
+
# Use datum id as root table if available, otherwise create synthetic IDs
|
|
1072
|
+
if self._has_unique_id_field:
|
|
1073
|
+
try:
|
|
1074
|
+
datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
|
|
1075
|
+
tb = feather.read_table(datum_id_path, memory_map=True)
|
|
1076
|
+
except (FileNotFoundError, pa.ArrowInvalid):
|
|
1077
|
+
# Create a synthetic ID table
|
|
1078
|
+
# Using the first sidecar to determine table size
|
|
1079
|
+
first_sidecar = next(iter(sidecars_to_load)) if sidecars_to_load else ""
|
|
1080
|
+
tb = self._create_synthetic_id_table(key, first_sidecar)
|
|
1081
|
+
else:
|
|
1082
|
+
# Create synthetic IDs when no unique_id_field is available
|
|
1083
|
+
first_sidecar = next(iter(sidecars_to_load)) if sidecars_to_load else ""
|
|
1084
|
+
tb = self._create_synthetic_id_table(key, first_sidecar)
|
|
1085
|
+
|
|
878
1086
|
for sidecar in sidecars_to_load:
|
|
879
1087
|
path = self.projection.tile_destination
|
|
880
1088
|
if sidecar == "":
|
|
881
1089
|
path = path / Path(key).with_suffix(".feather")
|
|
882
1090
|
else:
|
|
883
1091
|
path = path / Path(key).with_suffix(f".{sidecar}.feather")
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
1092
|
+
|
|
1093
|
+
try:
|
|
1094
|
+
carfile = feather.read_table(path, memory_map=True)
|
|
1095
|
+
for col in carfile.column_names:
|
|
1096
|
+
if col in self.fields:
|
|
1097
|
+
tb = tb.append_column(col, carfile[col])
|
|
1098
|
+
except (FileNotFoundError, pa.ArrowInvalid) as e:
|
|
1099
|
+
logger.warning(f"Error loading data for key {key}, sidecar {sidecar}: {e}")
|
|
1100
|
+
continue
|
|
1101
|
+
|
|
888
1102
|
tbs.append(tb)
|
|
889
1103
|
|
|
1104
|
+
if not tbs:
|
|
1105
|
+
raise ValueError("No data could be loaded. Data files may be missing or corrupt.")
|
|
1106
|
+
|
|
890
1107
|
self._tb = pa.concat_tables(tbs)
|
|
891
1108
|
|
|
1109
|
+
def _create_synthetic_id_table(self, key, sidecar):
|
|
1110
|
+
"""
|
|
1111
|
+
Create a synthetic table with position indices when datum_id file isn't available
|
|
1112
|
+
or when unique_id_field isn't specified.
|
|
1113
|
+
"""
|
|
1114
|
+
# Try to determine the size of the table by loading a sidecar file
|
|
1115
|
+
path = self.projection.tile_destination
|
|
1116
|
+
if sidecar == "":
|
|
1117
|
+
path = path / Path(key).with_suffix(".feather")
|
|
1118
|
+
else:
|
|
1119
|
+
path = path / Path(key).with_suffix(f".{sidecar}.feather")
|
|
1120
|
+
|
|
1121
|
+
try:
|
|
1122
|
+
sidecar_tb = feather.read_table(path, memory_map=True)
|
|
1123
|
+
size = len(sidecar_tb)
|
|
1124
|
+
# Create a table with position indices as IDs
|
|
1125
|
+
position_indices = [f"pos_{i}" for i in range(size)]
|
|
1126
|
+
return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
|
|
1127
|
+
except Exception as e:
|
|
1128
|
+
logger.error(f"Failed to create synthetic IDs for {key}: {e}")
|
|
1129
|
+
# Return an empty table as fallback
|
|
1130
|
+
return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
|
|
1131
|
+
|
|
892
1132
|
def _download_data(self, fields: Optional[List[str]] = None) -> List[Tuple[str, str]]:
|
|
893
1133
|
"""
|
|
894
1134
|
Downloads the feather tree for user uploaded data.
|
|
@@ -902,16 +1142,26 @@ class AtlasMapData:
|
|
|
902
1142
|
logger.info("Downloading data")
|
|
903
1143
|
self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
|
|
904
1144
|
|
|
905
|
-
# Download specified or all sidecar fields + always download datum_id
|
|
1145
|
+
# Download specified or all sidecar fields + always download datum_id if available
|
|
906
1146
|
data_columns_to_load = [
|
|
907
1147
|
(str(field), str(sidecar))
|
|
908
1148
|
for field, sidecar in self.projection._registered_columns
|
|
909
1149
|
if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
|
|
910
1150
|
]
|
|
911
1151
|
|
|
912
|
-
#
|
|
913
|
-
|
|
1152
|
+
# Only download datum_id if we have a unique ID field
|
|
1153
|
+
if self._has_unique_id_field:
|
|
1154
|
+
try:
|
|
1155
|
+
self.projection._download_sidecar("datum_id")
|
|
1156
|
+
except ValueError as e:
|
|
1157
|
+
logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
|
|
1158
|
+
self._has_unique_id_field = False
|
|
1159
|
+
|
|
1160
|
+
# Download all required sidecars for the fields
|
|
1161
|
+
sidecars_to_download = set(sidecar for _, sidecar in data_columns_to_load if sidecar != "datum_id")
|
|
1162
|
+
for sidecar in sidecars_to_download:
|
|
914
1163
|
self.projection._download_sidecar(sidecar)
|
|
1164
|
+
|
|
915
1165
|
return data_columns_to_load
|
|
916
1166
|
|
|
917
1167
|
@property
|
|
{nomic-3.5.0 → nomic-3.5.2}/nomic/dataset.py
@@ -115,15 +115,12 @@ class AtlasClass(object):
 
     return response.json()
 
-def _validate_map_data_inputs(self, colorable_fields,
+def _validate_map_data_inputs(self, colorable_fields, data_sample):
     """Validates inputs to map data calls."""
 
     if not isinstance(colorable_fields, list):
         raise ValueError("colorable_fields must be a list of fields")
 
-    if id_field in colorable_fields:
-        raise Exception(f"Cannot color by unique id field: {id_field}")
-
     for field in colorable_fields:
         if field not in data_sample:
             raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.")
@@ -274,14 +271,12 @@ class AtlasClass(object):
 """
 Private method. validates upload data against the dataset arrow schema, and associated other checks.
 
-1. If unique_id_field is specified, validates that each datum has that field. If not, adds it and then notifies the user that it was added.
-
 Args:
     data: an arrow table.
     project: the atlas dataset you are validating the data for.
 
 Returns:
-
+    Validated pyarrow table.
 """
 if not isinstance(data, pa.Table):
     raise Exception("Invalid data type for upload: {}".format(type(data)))
@@ -295,8 +290,32 @@ class AtlasClass(object):
     msg = "Must include embeddings in embedding dataset upload."
     raise ValueError(msg)
 
-
-
+# Check and validate ID field if specified
+if "unique_id_field" in project.meta and project.meta["unique_id_field"] is not None:
+    id_field = project.meta["unique_id_field"]
+
+    # Check if ID field exists in data
+    if id_field not in data.column_names:
+        raise ValueError(
+            f"Data must contain the ID column `{id_field}` as specified in dataset's unique_id_field"
+        )
+
+    # Check for null values in ID field
+    if data[id_field].null_count > 0:
+        raise ValueError(
+            f"As your unique id field, {id_field} must not contain null values, but {data[id_field].null_count} found."
+        )
+
+    # Check ID field length (36 characters max)
+    if pa.types.is_string(data[id_field].type):
+        # Use a safer alternative to check string length
+        utf8_length_values = pc.utf8_length(data[id_field])  # type: ignore
+        max_length_scalar = pc.max(utf8_length_values)  # type: ignore
+        max_length = max_length_scalar.as_py()
+        if max_length > 36:
+            raise ValueError(
+                f"The id_field contains values greater than 36 characters. Atlas does not support id_fields longer than 36 characters."
+            )
 
 seen = set()
 for col in data.column_names:
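The new validation block above can be read as a standalone routine: the configured ID column must exist, contain no nulls, and (for string columns) hold values no longer than 36 characters. A sketch of the same checks outside the class:

    import pyarrow as pa
    import pyarrow.compute as pc

    def validate_id_column(data: pa.Table, id_field: str) -> None:
        if id_field not in data.column_names:
            raise ValueError(f"Data must contain the ID column `{id_field}`")
        if data[id_field].null_count > 0:
            raise ValueError(f"{id_field} must not contain null values")
        if pa.types.is_string(data[id_field].type):
            max_length = pc.max(pc.utf8_length(data[id_field])).as_py()
            if max_length is not None and max_length > 36:
                raise ValueError("id_field values may be at most 36 characters")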
@@ -313,11 +332,6 @@ class AtlasClass(object):
 # filling in nulls, etc.
 reformatted = {}
 
-if data[project.id_field].null_count > 0:
-    raise ValueError(
-        f"{project.id_field} must not contain null values, but {data[project.id_field].null_count} found."
-    )
-
 assert project.schema is not None, "Project schema not found."
 
 for field in project.schema:
@@ -335,10 +349,15 @@ class AtlasClass(object):
     f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version."
 )
 reformatted[field.name] = pc.fill_null(reformatted[field.name], "null")
-
-
-
-
+
+# Check for empty strings and replace with "null"
+# Separate the operations for better type checking
+binary_length_values = pc.binary_length(reformatted[field.name])  # type: ignore
+has_empty_strings = pc.equal(binary_length_values, 0)  # type: ignore
+if pc.any(has_empty_strings).as_py():  # type: ignore
+    mask = has_empty_strings.combine_chunks()
+    assert pa.types.is_boolean(mask.type)
+    reformatted[field.name] = pc.replace_with_mask(reformatted[field.name], mask, "null")  # type: ignore
 for field in data.schema:
     if not field.name in reformatted:
         if field.name == "_embeddings":
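The empty-string handling above mirrors the existing null fill: zero-length strings are detected with `binary_length` and replaced with the literal "null" via `replace_with_mask`. A minimal sketch mirroring the same calls on a plain chunked array:

    import pyarrow as pa
    import pyarrow.compute as pc

    values = pa.chunked_array([pa.array(["ok", "", "also ok"])])
    has_empty = pc.equal(pc.binary_length(values), 0)
    if pc.any(has_empty).as_py():
        values = pc.replace_with_mask(values, has_empty.combine_chunks(), "null")
    # values -> ["ok", "null", "also ok"]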
@@ -350,27 +369,12 @@ class AtlasClass(object):
     if project.meta["insert_update_delete_lock"]:
         raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.")
 
-    # The following two conditions should never occur given the above, but just in case...
-    assert project.id_field in data.column_names, f"Upload does not contain your specified id_field"
-
-    if not pa.types.is_string(data[project.id_field].type):
-        logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}")
-        data = data.drop([project.id_field]).append_column(
-            project.id_field, data[project.id_field].cast(pa.string())
-        )
-
     for key in data.column_names:
         if key.startswith("_"):
             if key == "_embeddings" or key == "_blob_hash":
                 continue
             raise ValueError("Metadata fields cannot start with _")
-
-    first_match = data.filter(
-        pa.compute.greater(pa.compute.utf8_length(data[project.id_field]), 36)  # type: ignore
-    ).to_pylist()[0][project.id_field]
-    raise ValueError(
-        f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters."
-    )
+
     return data
 
 def _get_organization(self, organization_slug=None, organization_id=None) -> Tuple[str, str]:
@@ -696,36 +700,6 @@ class AtlasProjection:
 def tile_destination(self):
     return Path("~/.nomic/cache", self.id).expanduser()
 
-@property
-def datum_id_field(self):
-    return self.dataset.meta["unique_id_field"]
-
-def _get_atoms(self, ids: List[str]) -> List[Dict]:
-    """
-    Retrieves atoms by id
-
-    Args:
-        ids: list of atom ids
-
-    Returns:
-        A dictionary containing the resulting atoms, keyed by atom id.
-
-    """
-
-    if not isinstance(ids, list):
-        raise ValueError("You must specify a list of ids when getting data.")
-
-    response = requests.post(
-        self.dataset.atlas_api_path + "/v1/project/atoms/get",
-        headers=self.dataset.header,
-        json={"project_id": self.dataset.id, "index_id": self.atlas_index_id, "atom_ids": ids},
-    )
-
-    if response.status_code == 200:
-        return response.json()["atoms"]
-    else:
-        raise Exception(response.text)
-
 
 class AtlasDataStream(AtlasClass):
     def __init__(self, name: Optional[str] = "contrastors"):
@@ -766,7 +740,7 @@ class AtlasDataset(AtlasClass):
 
 * **identifier** - The dataset identifier in the form `dataset` or `organization/dataset`. If no organization is passed, the organization tied to the API key you logged in to Nomic with will be used.
 * **description** - A description for the dataset.
-* **unique_id_field** -
+* **unique_id_field** - A field that uniquely identifies each data point.
 * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
 * **dataset_id** - An alternative way to load a dataset is by passing the dataset_id directly. This only works if a dataset exists.
 """
@@ -805,13 +779,6 @@ class AtlasDataset(AtlasClass):
 dataset_id = dataset["id"]
 
 if dataset_id is None:  # if there is no existing project, make a new one.
-    if unique_id_field is None:  # if not all parameters are specified, we weren't trying to make a project
-        raise ValueError(f"Dataset `{identifier}` does not exist.")
-
-    # if modality is None:
-    #     raise ValueError("You must specify a modality when creating a new dataset.")
-    #
-    # assert modality in ['text', 'embedding'], "Modality must be either `text` or `embedding`"
     assert identifier is not None
 
     dataset_id = self._create_project(
@@ -840,7 +807,7 @@ class AtlasDataset(AtlasClass):
     self,
     identifier: str,
     description: Optional[str],
-    unique_id_field: str,
+    unique_id_field: Optional[str] = None,
     is_public: bool = True,
 ):
     """
@@ -852,7 +819,7 @@ class AtlasDataset(AtlasClass):
 
 * **identifier** - The identifier for the dataset.
 * **description** - A description for the dataset.
-* **unique_id_field** -
+* **unique_id_field** - A field that uniquely identifies each data point.
 * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
 
 **Returns:** project_id on success.
@@ -865,15 +832,6 @@ class AtlasDataset(AtlasClass):
 if "/" in identifier:
     org_name = identifier.split("/")[0]
     logger.info(f"Organization name: `{org_name}`")
-# supported_modalities = ['text', 'embedding']
-# if modality not in supported_modalities:
-#     msg = 'Tried to create dataset with modality: {}, but Atlas only supports: {}'.format(
-#         modality, supported_modalities
-#     )
-#     raise ValueError(msg)
-
-if unique_id_field is None:
-    raise ValueError("You must specify a unique id field")
 if description is None:
     description = ""
 response = requests.post(
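With the guard removed above, `unique_id_field` becomes genuinely optional when creating a dataset. A hypothetical construction sketch (organization and dataset names are placeholders, and an authenticated Nomic session is assumed):

    from nomic import AtlasDataset

    dataset = AtlasDataset(
        identifier="my-org/my-dataset",
        description="Example dataset without an explicit unique ID field",
        is_public=False,
    )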
@@ -884,7 +842,6 @@ class AtlasDataset(AtlasClass):
     "project_name": project_slug,
     "description": description,
     "unique_id_field": unique_id_field,
-    # 'modality': modality,
     "is_public": is_public,
 },
 )
@@ -939,11 +896,12 @@ class AtlasDataset(AtlasClass):
 
 @property
 def id(self) -> str:
-    """The
+    """The ID of the dataset."""
     return self.meta["id"]
 
 @property
 def id_field(self) -> str:
+    """The unique_id_field of the dataset."""
     return self.meta["unique_id_field"]
 
 @property
@@ -1147,7 +1105,7 @@ class AtlasDataset(AtlasClass):
 colorable_fields = []
 
 for field in self.dataset_fields:
-    if field not in [
+    if field not in [indexed_field] and not field.startswith("_"):
         colorable_fields.append(field)
 
 build_template = {}
@@ -1410,98 +1368,126 @@ class AtlasDataset(AtlasClass):
|
|
|
1410
1368
|
Uploads blobs to the server and associates them with the data.
|
|
1411
1369
|
Blobs must reference objects stored locally
|
|
1412
1370
|
"""
|
|
1371
|
+
data_as_table: pa.Table
|
|
1413
1372
|
if isinstance(data, DataFrame):
|
|
1414
|
-
|
|
1373
|
+
data_as_table = pa.Table.from_pandas(data)
|
|
1415
1374
|
elif isinstance(data, list):
|
|
1416
|
-
|
|
1417
|
-
elif
|
|
1375
|
+
data_as_table = pa.Table.from_pylist(data)
|
|
1376
|
+
elif isinstance(data, pa.Table):
|
|
1377
|
+
data_as_table = data
|
|
1378
|
+
else:
|
|
1418
1379
|
raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.")
|
|
1419
1380
|
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
# add hash to data as _blob_hash
|
|
1425
|
-
# set indexed_field to _blob_hash
|
|
1426
|
-
# call _add_data
|
|
1427
|
-
|
|
1428
|
-
# Cast self id field to string for merged data lower down on function
|
|
1429
|
-
data = data.set_column( # type: ignore
|
|
1430
|
-
data.schema.get_field_index(self.id_field), self.id_field, pc.cast(data[self.id_field], pa.string()) # type: ignore
|
|
1431
|
-
)
|
|
1381
|
+
# Compute dataset length
|
|
1382
|
+
data_length = len(data_as_table)
|
|
1383
|
+
if data_length != len(blobs):
|
|
1384
|
+
raise ValueError(f"Number of data points ({data_length}) must match number of blobs ({len(blobs)})")
|
|
1432
1385
|
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
images = []
|
|
1439
|
-
for
|
|
1440
|
-
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1386
|
+
TEMP_ID_COLUMN = "_nomic_internal_temp_id"
|
|
1387
|
+
temp_id_values = [str(i) for i in range(data_length)]
|
|
1388
|
+
data_as_table = data_as_table.append_column(TEMP_ID_COLUMN, pa.array(temp_id_values, type=pa.string()))
|
|
1389
|
+
blob_upload_endpoint = "/v1/project/data/add/blobs"
|
|
1390
|
+
actual_temp_ids = data_as_table[TEMP_ID_COLUMN].to_pylist()
|
|
1391
|
+
images = [] # List of (temp_id, image_bytes)
|
|
1392
|
+
for i in tqdm(range(data_length), desc="Processing images"):
|
|
1393
|
+
current_temp_id = actual_temp_ids[i]
|
|
1394
|
+
blob_item = blobs[i]
|
|
1395
|
+
|
|
1396
|
+
processed_blob_value = None
|
|
1397
|
+
if (isinstance(blob_item, str) or isinstance(blob_item, Path)) and os.path.exists(blob_item):
|
|
1398
|
+
image = Image.open(blob_item).convert("RGB")
|
|
1444
1399
|
if image.height > 512 or image.width > 512:
|
|
1445
1400
|
image = image.resize((512, 512))
|
|
1446
1401
|
buffered = BytesIO()
|
|
1447
1402
|
image.save(buffered, format="JPEG")
|
|
1448
|
-
|
|
1449
|
-
elif isinstance(
|
|
1450
|
-
|
|
1451
|
-
elif isinstance(
|
|
1452
|
-
|
|
1453
|
-
if
|
|
1454
|
-
|
|
1403
|
+
processed_blob_value = buffered.getvalue()
|
|
1404
|
+
elif isinstance(blob_item, bytes):
|
|
1405
|
+
processed_blob_value = blob_item
|
|
1406
|
+
elif isinstance(blob_item, Image.Image):
|
|
1407
|
+
img_pil = blob_item.convert("RGB") # Ensure it's PIL Image for methods
|
|
1408
|
+
if img_pil.height > 512 or img_pil.width > 512:
|
|
1409
|
+
img_pil = img_pil.resize((512, 512))
|
|
1455
1410
|
buffered = BytesIO()
|
|
1456
|
-
|
|
1457
|
-
|
|
1411
|
+
img_pil.save(buffered, format="JPEG")
|
|
1412
|
+
processed_blob_value = buffered.getvalue()
|
|
1458
1413
|
else:
|
|
1459
|
-
raise ValueError(
|
|
1414
|
+
raise ValueError(
|
|
1415
|
+
f"Invalid blob type for item at index {i} (temp_id: {current_temp_id}). Must be a path, bytes, or PIL Image. Got: {type(blob_item)}"
|
|
1416
|
+
)
|
|
1417
|
+
|
|
1418
|
+
if processed_blob_value is not None:
|
|
1419
|
+
images.append((current_temp_id, processed_blob_value))
|
|
1460
1420
|
|
|
1461
1421
|
batch_size = 40
|
|
1462
|
-
num_workers =
|
|
1422
|
+
num_workers = 2
|
|
1423
|
+
|
|
1424
|
+
def send_request(batch_start_index):
|
|
1425
|
+
image_batch = images[batch_start_index : batch_start_index + batch_size]
|
|
1426
|
+
temp_ids_in_batch = [item_id for item_id, _ in image_batch]
|
|
1427
|
+
blobs_for_api = [("blobs", blob_val) for _, blob_val in image_batch]
|
|
1463
1428
|
|
|
1464
|
-
def send_request(i):
|
|
1465
|
-
image_batch = images[i : i + batch_size]
|
|
1466
|
-
ids = [uuid for uuid, _ in image_batch]
|
|
1467
|
-
blobs = [("blobs", blob) for _, blob in image_batch]
|
|
1468
1429
|
response = requests.post(
|
|
1469
1430
|
self.atlas_api_path + blob_upload_endpoint,
|
|
1470
1431
|
headers=self.header,
|
|
1471
|
-
data={"dataset_id": self.id},
|
|
1472
|
-
files=
|
|
1432
|
+
data={"dataset_id": self.id}, # self.id is project_id
|
|
1433
|
+
files=blobs_for_api,
|
|
1473
1434
|
)
|
|
1474
1435
|
if response.status_code != 200:
|
|
1475
|
-
|
|
1476
|
-
|
|
1436
|
+
failed_ids_sample = temp_ids_in_batch[:5]
|
|
1437
|
+
logger.error(
|
|
1438
|
+
f"Blob upload request failed for batch starting with temp_ids: {failed_ids_sample}. Status: {response.status_code}, Response: {response.text}"
|
|
1439
|
+
)
|
|
1440
|
+
raise Exception(f"Blob upload failed: {response.text}")
|
|
1441
|
+
return {temp_id: blob_hash for temp_id, blob_hash in zip(temp_ids_in_batch, response.json()["hashes"])}
|
|
1477
1442
|
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1443
|
+
upload_pbar = pbar # Use passed-in pbar if available
|
|
1444
|
+
close_upload_pbar_locally = False
|
|
1445
|
+
if upload_pbar is None:
|
|
1446
|
+
upload_pbar = tqdm(total=len(images), desc="Uploading blobs to Atlas")
|
|
1447
|
+
close_upload_pbar_locally = True
|
|
1481
1448
|
|
|
1482
|
-
|
|
1483
|
-
returned_ids = []
|
|
1449
|
+
returned_temp_ids = []
|
|
1484
1450
|
returned_hashes = []
|
|
1451
|
+
succeeded_uploads = 0
|
|
1485
1452
|
|
|
1486
|
-
succeeded = 0
|
|
1487
1453
|
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
|
|
1488
|
-
futures = {executor.submit(send_request, i): i for i in range(0, len(
|
|
1454
|
+
futures = {executor.submit(send_request, i): i for i in range(0, len(images), batch_size)}
|
|
1489
1455
|
|
|
1490
1456
|
for future in concurrent.futures.as_completed(futures):
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1457
|
+
try:
|
|
1458
|
+
response_dict = future.result() # This is {temp_id: blob_hash}
|
|
1459
|
+
for temp_id, blob_hash_val in response_dict.items():
|
|
1460
|
+
returned_temp_ids.append(temp_id)
|
|
1461
|
+
returned_hashes.append(blob_hash_val)
|
|
1462
|
+
succeeded_uploads += len(response_dict)
|
|
1463
|
+
if upload_pbar:
|
|
1464
|
+
upload_pbar.update(len(response_dict))
|
|
1465
|
+
except Exception as e:
|
|
1466
|
+
logger.error(f"An error occurred during blob upload processing for a batch: {e}")
|
|
1467
|
+
# Optionally, collect failed batch info here if needed for partial success
|
|
1468
|
+
|
|
1469
|
+
if close_upload_pbar_locally and upload_pbar:
|
|
1470
|
+
upload_pbar.close()
|
|
1471
|
+
|
|
1472
|
+
hash_schema = pa.schema([(TEMP_ID_COLUMN, pa.string()), ("_blob_hash", pa.string())])
|
|
1473
|
+
|
|
1474
|
+
merged_data_as_table: pa.Table
|
|
1475
|
+
if succeeded_uploads > 0: # Only create hash_tb if there are successful uploads
|
|
1476
|
+
hash_tb = pa.Table.from_pydict(
|
|
1477
|
+
{TEMP_ID_COLUMN: returned_temp_ids, "_blob_hash": returned_hashes}, schema=hash_schema
|
|
1478
|
+
)
|
|
1479
|
+
merged_data_as_table = data_as_table.join(right_table=hash_tb, keys=TEMP_ID_COLUMN, join_type="left outer")
|
|
1480
|
+
else: # No successful uploads, so no hashes to merge, but keep original data structure
|
|
1481
|
+
# Need to ensure _blob_hash column is added with nulls, and id_field is present
|
|
1482
|
+
if "_blob_hash" not in data_as_table.column_names:
|
|
1483
|
+
data_as_table = data_as_table.append_column(
|
|
1484
|
+
"_blob_hash", pa.nulls(data_as_table.num_rows, type=pa.string())
|
|
1485
|
+
)
|
|
1486
|
+
merged_data_as_table = data_as_table
|
|
1500
1487
|
|
|
1501
|
-
|
|
1502
|
-
merged_data = data.join(right_table=hash_tb, keys=self.id_field) # type: ignore
|
|
1488
|
+
merged_data_as_table = merged_data_as_table.drop_columns([TEMP_ID_COLUMN])
|
|
1503
1489
|
|
|
1504
|
-
self._add_data(
|
|
1490
|
+
self._add_data(merged_data_as_table, pbar=pbar) # Pass original pbar argument
|
|
1505
1491
|
|
|
1506
1492
|
def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None):
|
|
1507
1493
|
"""
|
|
@@ -1580,12 +1566,8 @@ class AtlasDataset(AtlasClass):
     None
 """
 
-
-
-num_workers = 10
-
+num_workers = 2
 # Each worker currently is too slow beyond a shard_size of 10000
-
 # The heuristic here is: Never let shards be more than 10,000 items,
 # OR more than 16MB uncompressed. Whichever is smaller.
 
@@ -1701,7 +1683,7 @@ class AtlasDataset(AtlasClass):
     else:
         logger.info("Upload succeeded.")
 
-def update_maps(self, data: List[Dict], embeddings: Optional[np.ndarray] = None, num_workers: int =
+def update_maps(self, data: List[Dict], embeddings: Optional[np.ndarray] = None, num_workers: int = 2):
     """
     Utility method to update a project's maps by adding the given data.
 
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|