nomic 3.5.0.tar.gz → 3.5.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nomic might be problematic.
- {nomic-3.5.0 → nomic-3.5.1}/PKG-INFO +1 -1
- {nomic-3.5.0 → nomic-3.5.1}/nomic/atlas.py +3 -39
- {nomic-3.5.0 → nomic-3.5.1}/nomic/data_operations.py +337 -87
- {nomic-3.5.0 → nomic-3.5.1}/nomic/dataset.py +135 -123
- {nomic-3.5.0 → nomic-3.5.1}/nomic.egg-info/PKG-INFO +1 -1
- {nomic-3.5.0 → nomic-3.5.1}/setup.py +1 -1
- {nomic-3.5.0 → nomic-3.5.1}/README.md +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/__init__.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/aws/__init__.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/aws/sagemaker.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/cli.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/data_inference.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/embed.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/pl_callbacks/__init__.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/pl_callbacks/pl_callback.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/settings.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic/utils.py +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic.egg-info/SOURCES.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic.egg-info/dependency_links.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic.egg-info/entry_points.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic.egg-info/requires.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/nomic.egg-info/top_level.txt +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/pyproject.toml +0 -0
- {nomic-3.5.0 → nomic-3.5.1}/setup.cfg +0 -0

nomic/atlas.py (+3 -39):

```diff
@@ -3,7 +3,6 @@ This class allows for programmatic interactions with Atlas - Nomic's neural data
 or in a Jupyter Notebook to organize and interact with your unstructured data.
 """
 
-import uuid
 from typing import Dict, Iterable, List, Optional, Union
 
 import numpy as np
@@ -42,7 +41,7 @@ def map_data(
         embeddings: An [N,d] numpy array containing the N embeddings to add.
         identifier: A name for your dataset that is used to generate the dataset identifier. A unique name will be chosen if not supplied.
         description: The description of your dataset
-        id_field: Specify
+        id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
         is_public: Should the dataset be accessible outside your Nomic Atlas organization.
         indexed_field: The text field from the dataset that will be used to create embeddings, which determines the layout of the data map in Atlas. Required for text data but won't have an impact if uploading embeddings or image blobs.
         projection: Options for configuring the 2D projection algorithm.
@@ -86,9 +85,6 @@ def map_data(
         # default to vision v1.5
         embedding_model = NomicEmbedOptions(model="nomic-embed-vision-v1.5")
 
-    if id_field is None:
-        id_field = ATLAS_DEFAULT_ID_FIELD
-
     project_name = get_random_name()
 
     dataset_name = project_name
@@ -100,38 +96,6 @@ def map_data(
     if description:
         description = description
 
-    # no metadata was specified
-    added_id_field = False
-
-    if data is None:
-        added_id_field = True
-        if embeddings is not None:
-            data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(embeddings))]
-        elif blobs is not None:
-            data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(blobs))]
-        else:
-            raise ValueError("You must specify either data, embeddings, or blobs")
-
-    if id_field == ATLAS_DEFAULT_ID_FIELD and data is not None:
-        if isinstance(data, list) and id_field not in data[0]:
-            added_id_field = True
-            for i in range(len(data)):
-                # do not modify object the user passed in - also ensures IDs are unique if two input datums are the same *object*
-                data[i] = data[i].copy()
-                data[i][id_field] = b64int(i)
-        elif isinstance(data, DataFrame) and id_field not in data.columns:
-            data[id_field] = [b64int(i) for i in range(data.shape[0])]
-            added_id_field = True
-        elif isinstance(data, pa.Table) and not id_field in data.column_names:  # type: ignore
-            ids = pa.array([b64int(i) for i in range(len(data))])
-            data = data.append_column(id_field, ids)  # type: ignore
-            added_id_field = True
-        elif id_field not in data[0]:
-            raise ValueError("map_data data must be a list of dicts, a pandas dataframe, or a pyarrow table")
-
-    if added_id_field:
-        logger.warning("An ID field was not specified in your data so one was generated for you in insertion order.")
-
     dataset = AtlasDataset(
         identifier=dataset_name, description=description, unique_id_field=id_field, is_public=is_public
     )
@@ -202,7 +166,7 @@ def map_embeddings(
     Args:
         embeddings: An [N,d] numpy array containing the batch of N embeddings to add.
         data: An [N,] element list of dictionaries containing metadata for each embedding.
-        id_field: Specify
+        id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
         name: A name for your dataset. Specify in the format `organization/project` to create in a specific organization.
         description: A description for your map.
         is_public: Should this embedding map be public? Private maps can only be accessed by members of your organization.
@@ -250,7 +214,7 @@ def map_text(
     Args:
         data: An [N,] element iterable of dictionaries containing metadata for each embedding.
         indexed_field: The name the data field containing the text your want to map.
-        id_field: Specify
+        id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
        name: A name for your dataset. Specify in the format `organization/project` to create in a specific organization.
        description: A description for your map.
        build_topic_model: Builds a hierarchical topic model over your data to discover patterns.
```
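
The net effect in `nomic/atlas.py` is that `map_data` no longer fabricates an ID column client-side; whatever `id_field` the caller passes (possibly `None`) is forwarded straight to `AtlasDataset`. A minimal usage sketch, assuming the 3.5.1 behavior shown in these hunks (the dataset identifier and field names are illustrative):

```python
# Hedged sketch: relies on the 3.5.1 change that an explicit id column is no longer required.
from nomic import atlas

documents = [
    {"text": "first document"},   # note: no explicit id column
    {"text": "second document"},
]

# In 3.5.0 this call would have injected an ATLAS_DEFAULT_ID_FIELD column locally;
# per this diff, 3.5.1 leaves the data untouched and lets Atlas track rows by position.
dataset = atlas.map_data(
    data=documents,
    indexed_field="text",
    identifier="my-org/example-dataset",  # hypothetical identifier
)
```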

nomic/data_operations.py (+337 -87):

```diff
@@ -25,7 +25,6 @@ class AtlasMapDuplicates:
 
     def __init__(self, projection: "AtlasProjection"):  # type: ignore
         self.projection = projection
-        self.id_field = self.projection.dataset.id_field
 
         duplicate_columns = [
             (field, sidecar)
@@ -40,7 +39,13 @@ class AtlasMapDuplicates:
 
         self._duplicate_column = duplicate_columns[0]
         self._cluster_column = cluster_columns[0]
+        self.duplicate_field = None
+        self.cluster_field = None
         self._tb = None
+        self._has_unique_id_field = (
+            "unique_id_field" in self.projection.dataset.meta
+            and self.projection.dataset.meta["unique_id_field"] is not None
+        )
 
     def _load_duplicates(self):
         """
```

```diff
@@ -51,30 +56,80 @@ class AtlasMapDuplicates:
         self.duplicate_field = self._duplicate_column[0].lstrip("_")
         self.cluster_field = self._cluster_column[0].lstrip("_")
         logger.info("Loading duplicates")
+        id_field_name = "_position_index"
+        if self._has_unique_id_field:
+            id_field_name = self.projection.dataset.meta["unique_id_field"]
+
         for key in tqdm(self.projection._manifest["key"].to_pylist()):
-            # Use datum id as root table
-            …
+            # Use datum id as root table if available, otherwise create synthetic IDs
+            if self._has_unique_id_field:
+                try:
+                    datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+                    tb = feather.read_table(datum_id_path, memory_map=True)
+                except (FileNotFoundError, pa.ArrowInvalid):
+                    # Create a synthetic ID table
+                    tb = self._create_synthetic_id_table(key, duplicate_sidecar)
+            else:
+                # Create synthetic IDs when no unique_id_field is available
+                tb = self._create_synthetic_id_table(key, duplicate_sidecar)
 
+            path = self.projection.tile_destination
             if duplicate_sidecar == "":
                 path = path / Path(key).with_suffix(".feather")
             else:
                 path = path / Path(key).with_suffix(f".{duplicate_sidecar}.feather")
 
-            …
+            try:
+                duplicate_tb = feather.read_table(path, memory_map=True)
+                for field in (self._duplicate_column[0], self._cluster_column[0]):
+                    tb = tb.append_column(field, duplicate_tb[field])
+                tbs.append(tb)
+            except (FileNotFoundError, pa.ArrowInvalid) as e:
+                logger.warning(f"Error loading duplicate data for key {key}: {e}")
+                continue
+
+        if not tbs:
+            raise ValueError("No duplicate data could be loaded. Duplicate data files may be missing or corrupt.")
+
+        self._tb = pa.concat_tables(tbs).rename_columns([id_field_name, self.duplicate_field, self.cluster_field])
+
+    def _create_synthetic_id_table(self, key, sidecar):
+        """
+        Create a synthetic table with position indices when datum_id file isn't available
+        or when unique_id_field isn't specified.
+        """
+        # Try to determine the size of the table by loading the duplicate sidecar
+        path = self.projection.tile_destination
+        if sidecar == "":
+            path = path / Path(key).with_suffix(".feather")
+        else:
+            path = path / Path(key).with_suffix(f".{sidecar}.feather")
+
+        try:
+            sidecar_tb = feather.read_table(path, memory_map=True)
+            size = len(sidecar_tb)
+            # Create a table with position indices as IDs
+            position_indices = [f"pos_{i}" for i in range(size)]
+            return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
+        except Exception as e:
+            logger.error(f"Failed to create synthetic IDs for {key}: {e}")
+            # Return an empty table as fallback
+            return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
 
     def _download_duplicates(self):
         """
         Downloads the feather tree for duplicates.
         """
         logger.info("Downloading duplicates")
-        …
+
+        # Only download datum_id if we have a unique ID field
+        if self._has_unique_id_field:
+            try:
+                self.projection._download_sidecar("datum_id", overwrite=False)
+            except ValueError as e:
+                logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
+                self._has_unique_id_field = False
+
         assert self._cluster_column[1] == self._duplicate_column[1], "Cluster and duplicate should be in same sidecar"
         self.projection._download_sidecar(self._duplicate_column[1], overwrite=False)
 
```
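
The `_create_synthetic_id_table` helpers added across these classes all follow the same pattern: read the tile's sidecar feather file only to learn its row count, then fabricate a `_position_index` column of `pos_{i}` strings to stand in for a real unique ID. A standalone sketch of that pattern, with a placeholder file path:

```python
# Standalone sketch of the synthetic-ID pattern introduced in this diff; the path is a placeholder.
import pyarrow as pa
import pyarrow.feather as feather

def synthetic_id_table(sidecar_path: str) -> pa.Table:
    sidecar_tb = feather.read_table(sidecar_path, memory_map=True)
    position_indices = [f"pos_{i}" for i in range(len(sidecar_tb))]
    return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])

# The loaders then append the duplicate/cluster (or topic, coordinate, data) columns onto this
# root table and concatenate one table per tile, as _load_duplicates does above.
```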

```diff
@@ -100,17 +155,24 @@ class AtlasMapDuplicates:
 
     def deletion_candidates(self) -> List[str]:
         """
-
         Returns:
             The ids for all data points which are semantic duplicates and are candidates for being deleted from the dataset. If you remove these data points from your dataset, your dataset will be semantically deduplicated.
         """
-        …
+        id_field_name = "_position_index"
+        if self._has_unique_id_field:
+            id_field_name = self.projection.dataset.meta["unique_id_field"]
+
+        dupes = self.tb[id_field_name].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate"))  # type: ignore
         return dupes.to_pylist()
 
     def __repr__(self) -> str:
+        id_field_name = "_position_index"
+        if self._has_unique_id_field:
+            id_field_name = self.projection.dataset.meta["unique_id_field"]
+
         repr = f"===Atlas Duplicates for ({self.projection})\n"
         duplicate_count = len(
-            self.tb[
+            self.tb[id_field_name].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate"))  # type: ignore
         )
         cluster_count = len(self.tb[self.cluster_field].value_counts())
         repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n"
```
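
With this change, `deletion_candidates()` returns values from the dataset's unique ID column when one exists and synthetic `pos_{n}` strings otherwise. A hedged usage sketch, assuming the `maps` and `duplicates` accessors behave as in earlier nomic releases (the identifier is hypothetical):

```python
# Hedged usage sketch; the dataset identifier is made up and the `duplicates`
# accessor is assumed to work as in prior releases.
from nomic import AtlasDataset

dataset = AtlasDataset("my-org/example-dataset")
projection = dataset.maps[0]

ids_to_drop = projection.duplicates.deletion_candidates()
# With a unique_id_field configured these are your own IDs; without one they
# are synthetic "pos_<n>" position indices, per the hunks above.
```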

```diff
@@ -125,7 +187,6 @@ class AtlasMapTopics:
     def __init__(self, projection: "AtlasProjection"):  # type: ignore
         self.projection = projection
         self.dataset = projection.dataset
-        self.id_field = self.projection.dataset.id_field
         self._metadata = None
         self._hierarchy = None
         self._topic_columns = [
@@ -134,6 +195,9 @@ class AtlasMapTopics:
         assert len(self._topic_columns) > 0, "Topic modeling has not yet been run on this map."
         self.depth = len(self._topic_columns)
         self._tb = None
+        self._has_unique_id_field = (
+            "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
+        )
 
     def _load_topics(self):
         """
```

```diff
@@ -149,45 +213,96 @@ class AtlasMapTopics:
         # Should just be one sidecar
         topic_sidecar = set([sidecar for _, sidecar in self._topic_columns]).pop()
         logger.info("Loading topics")
+        id_field_name = "_position_index"
+        if self._has_unique_id_field:
+            id_field_name = self.dataset.meta["unique_id_field"]
+
         for key in tqdm(self.projection._manifest["key"].to_pylist()):
-            # Use datum id as root table
-            …
+            # Use datum id as root table if available, otherwise create a synthetic index
+            if self._has_unique_id_field:
+                try:
+                    datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+                    tb = feather.read_table(datum_id_path, memory_map=True)
+                except (FileNotFoundError, pa.ArrowInvalid):
+                    # If datum_id file doesn't exist, create a table with a position index
+                    logger.warning(f"Datum ID file not found for key {key}. Creating synthetic IDs.")
+                    tb = self._create_synthetic_id_table(key, topic_sidecar)
+            else:
+                # Create a table with a position index when no unique_id_field is available
+                tb = self._create_synthetic_id_table(key, topic_sidecar)
+
             path = self.projection.tile_destination
             if topic_sidecar == "":
                 path = path / Path(key).with_suffix(".feather")
             else:
                 path = path / Path(key).with_suffix(f".{topic_sidecar}.feather")
 
-            …
-                topic_ids_to_label
-            …
+            try:
+                topic_tb = feather.read_table(path, memory_map=True)
+                # Do this in depth order
+                for d in range(1, self.depth + 1):
+                    column = f"_topic_depth_{d}"
+                    if integer_topics:
+                        column = f"_topic_depth_{d}_int"
+                        topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
+                        assert label_df is not None
+                        topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
+                            topic_ids_to_label, on="topic_id", how="right"
+                        )
+                        new_column = f"_topic_depth_{d}"
+                        tb = tb.append_column(
+                            new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
+                        )
+                    else:
+                        tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
+                tbs.append(tb)
+            except (FileNotFoundError, pa.ArrowInvalid) as e:
+                logger.warning(f"Error loading topic data for key {key}: {e}")
+                continue
 
-        …
+        if not tbs:
+            raise ValueError("No topic data could be loaded. Topic data files may be missing or corrupt.")
+
+        renamed_columns = [id_field_name] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
         self._tb = pa.concat_tables(tbs).rename_columns(renamed_columns)
 
+    def _create_synthetic_id_table(self, key, sidecar):
+        """
+        Create a synthetic table with position indices when datum_id file isn't available
+        or when unique_id_field isn't specified.
+        """
+        # Try to determine the size of the table by loading the topic sidecar
+        path = self.projection.tile_destination
+        if sidecar == "":
+            path = path / Path(key).with_suffix(".feather")
+        else:
+            path = path / Path(key).with_suffix(f".{sidecar}.feather")
+
+        try:
+            topic_tb = feather.read_table(path, memory_map=True)
+            size = len(topic_tb)
+            # Create a table with position indices as IDs
+            position_indices = [f"pos_{i}" for i in range(size)]
+            return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
+        except Exception as e:
+            logger.error(f"Failed to create synthetic IDs for {key}: {e}")
+            # Return an empty table as fallback
+            return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
+
     def _download_topics(self):
         """
         Downloads the feather tree for topics.
         """
         logger.info("Downloading topics")
-        …
+
+        # Only download datum_id if we have a unique ID field
+        if self._has_unique_id_field:
+            try:
+                self.projection._download_sidecar("datum_id", overwrite=False)
+            except ValueError as e:
+                logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
+                self._has_unique_id_field = False
+
         topic_sidecars = set([sidecar for _, sidecar in self._topic_columns])
         assert len(topic_sidecars) == 1, "Multiple topic sidecars found."
         self.projection._download_sidecar(topic_sidecars.pop(), overwrite=False)
```
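
The depth-wise relabeling in `_load_topics` maps each datum's integer topic id to a short description with a pandas right merge. A small illustrative sketch of that merge in isolation (column names follow the diff, the data is made up):

```python
# Illustrative sketch of the label merge used in _load_topics; values are fabricated.
import pandas as pd

label_df = pd.DataFrame(
    {"topic_id": [1, 2], "depth": [1, 1], "topic_short_description": ["sports", "politics"]}
)
# One topic id per datum at this depth:
topic_ids_to_label = pd.Series([2, 1, 2], name="topic_id")

merged = pd.DataFrame(label_df[label_df["depth"] == 1]).merge(
    topic_ids_to_label, on="topic_id", how="right"
)
# merged["topic_short_description"] -> ["politics", "sports", "politics"],
# i.e. one label per datum, in the datum order of the right-hand series.
```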

```diff
@@ -286,7 +401,10 @@ class AtlasMapTopics:
             raise ValueError("Topic depth out of range.")
 
         # Unique datum id column to aggregate
-        datum_id_col =
+        datum_id_col = "_position_index"
+        if "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None:
+            datum_id_col = self.dataset.meta["unique_id_field"]
+
         df = self.df
 
         topic_datum_dict = df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict()
@@ -333,8 +451,12 @@ class AtlasMapTopics:
             A list of `{topic, count}` dictionaries, sorted from largest count to smallest count.
         """
         data = AtlasMapData(self.projection, fields=[time_field])
-        …
+        id_field_name = "_position_index"
+        if "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None:
+            id_field_name = self.dataset.meta["unique_id_field"]
+
+        time_data = data.tb.select([id_field_name, time_field])
+        merged_tb = self.tb.join(time_data, id_field_name, join_type="inner").combine_chunks()
 
         del time_data  # free up memory
 
@@ -343,12 +465,12 @@ class AtlasMapTopics:
         topic_densities = {}
         for depth in range(1, self.depth + 1):
             topic_column = f"topic_depth_{depth}"
-            topic_counts = merged_tb.group_by(topic_column).aggregate([(
+            topic_counts = merged_tb.group_by(topic_column).aggregate([(id_field_name, "count")]).to_pandas()
             for _, row in topic_counts.iterrows():
                 topic = row[topic_column]
                 if topic not in topic_densities:
                     topic_densities[topic] = 0
-                topic_densities[topic] += row[
+                topic_densities[topic] += row[id_field_name + "_count"]
         return topic_densities
 
     def vector_search_topics(self, queries: np.ndarray, k: int = 32, depth: int = 3) -> Dict:
```
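
In pyarrow, `group_by(...).aggregate([(column, "count")])` emits a `{column}_count` output column, which is why the density loop above reads `row[id_field_name + "_count"]`. A compact sketch of that aggregation, with fabricated rows:

```python
# Sketch of the count aggregation used for topic densities; column names mirror the diff.
import pyarrow as pa

tb = pa.table({
    "topic_depth_1": ["sports", "sports", "politics"],
    "id": ["a", "b", "c"],
})
counts = tb.group_by("topic_depth_1").aggregate([("id", "count")]).to_pandas()
# counts contains a "topic_depth_1" key column and an "id_count" column,
# e.g. sports -> 2, politics -> 1.
```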

```diff
@@ -448,10 +570,12 @@ class AtlasMapEmbeddings:
 
     def __init__(self, projection: "AtlasProjection"):  # type: ignore
         self.projection = projection
-        self.id_field = self.projection.dataset.id_field
         self.dataset = projection.dataset
         self._tb: pa.Table = None
         self._latent = None
+        self._has_unique_id_field = (
+            "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
+        )
 
     @property
     def df(self):
@@ -481,23 +605,81 @@ class AtlasMapEmbeddings:
         tbs = []
         coord_sidecar = self.projection._get_sidecar_from_field("x")
         for key in tqdm(self.projection._manifest["key"].to_pylist()):
-            # Use datum id as root table
-            …
+            # Use datum id as root table if available, otherwise create synthetic IDs
+            if self._has_unique_id_field:
+                try:
+                    datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+                    tb = feather.read_table(datum_id_path, memory_map=True)
+                except (FileNotFoundError, pa.ArrowInvalid):
+                    # Create a synthetic ID table
+                    tb = self._create_synthetic_id_table(key, coord_sidecar)
+            else:
+                # Create synthetic IDs when no unique_id_field is available
+                tb = self._create_synthetic_id_table(key, coord_sidecar)
 
             path = self.projection.tile_destination
             if coord_sidecar == "":
                 path = path / Path(key).with_suffix(".feather")
             else:
                 path = path / Path(key).with_suffix(f".{coord_sidecar}.feather")
-            …
+
+            try:
+                carfile = feather.read_table(path, memory_map=True)
+                for col in carfile.column_names:
+                    if col in ["x", "y"]:
+                        tb = tb.append_column(col, carfile[col])
+                tbs.append(tb)
+            except (FileNotFoundError, pa.ArrowInvalid) as e:
+                logger.warning(f"Error loading embedding data for key {key}: {e}")
+                continue
+
+        if not tbs:
+            raise ValueError("No embedding data could be loaded. Embedding data files may be missing or corrupt.")
+
         self._tb = pa.concat_tables(tbs)
         return self._tb
 
+    def _create_synthetic_id_table(self, key, sidecar):
+        """
+        Create a synthetic table with position indices when datum_id file isn't available
+        or when unique_id_field isn't specified.
+        """
+        # Try to determine the size of the table by loading the coordinate sidecar
+        path = self.projection.tile_destination
+        if sidecar == "":
+            path = path / Path(key).with_suffix(".feather")
+        else:
+            path = path / Path(key).with_suffix(f".{sidecar}.feather")
+
+        try:
+            coord_tb = feather.read_table(path, memory_map=True)
+            size = len(coord_tb)
+            # Create a table with position indices as IDs
+            position_indices = [f"pos_{i}" for i in range(size)]
+            return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
+        except Exception as e:
+            logger.error(f"Failed to create synthetic IDs for {key}: {e}")
+            # Return an empty table as fallback
+            return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
+
+    def _download_projected(self) -> List[Path]:
+        """
+        Downloads the feather tree for projection coordinates.
+        """
+        logger.info("Downloading projected embeddings")
+        # Note that y coord should be in same sidecar
+        coord_sidecar = self.projection._get_sidecar_from_field("x")
+
+        # Only download datum_id if we have a unique ID field
+        if self._has_unique_id_field:
+            try:
+                self.projection._download_sidecar("datum_id", overwrite=False)
+            except ValueError as e:
+                logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
+                self._has_unique_id_field = False
+
+        return self.projection._download_sidecar(coord_sidecar, overwrite=False)
+
     @property
     def projected(self) -> pd.DataFrame:
         """
@@ -531,16 +713,6 @@ class AtlasMapEmbeddings:
             all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
         return np.vstack(all_embeddings)
 
-    def _download_projected(self) -> List[Path]:
-        """
-        Downloads the feather tree for projection coordinates.
-        """
-        logger.info("Downloading projected embeddings")
-        # Note that y coord should be in same sidecar
-        coord_sidecar = self.projection._get_sidecar_from_field("x")
-        self.projection._download_sidecar("datum_id", overwrite=False)
-        return self.projection._download_sidecar(coord_sidecar, overwrite=False)
-
     def _download_latent(self) -> List[Path]:
         """
         Downloads the feather tree for embeddings.
```

```diff
@@ -670,12 +842,17 @@ class AtlasMapTags:
     def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False):  # type: ignore
         self.projection = projection
         self.dataset = projection.dataset
-        self.id_field = self.projection.dataset.id_field
         # Pre-fetch datum ids first upon initialization
-        …
+        self._has_unique_id_field = (
+            "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
+        )
+
+        if self._has_unique_id_field:
+            try:
+                self.projection._download_sidecar("datum_id")
+            except Exception as e:
+                logger.warning(f"Failed to fetch datum ids: {e}. Will use synthetic IDs.")
+                self._has_unique_id_field = False
         self.auto_cleanup = auto_cleanup
 
     @property
@@ -748,20 +925,38 @@ class AtlasMapTags:
         """
         tag_paths = self._download_tag(tag_name, overwrite=overwrite)
         datum_ids = []
+        id_field_name = "_position_index"
+        if self._has_unique_id_field:
+            id_field_name = self.dataset.meta["unique_id_field"]
+
         for path in tag_paths:
             tb = feather.read_table(path)
             last_coord = path.name.split(".")[0]
-            …
+
+            # Get ID information - either from datum_id file or create synthetic IDs
+            if self._has_unique_id_field:
+                try:
+                    tile_path = path.with_name(last_coord + ".datum_id.feather")
+                    tile_tb = feather.read_table(tile_path).select([id_field_name])
+                except (FileNotFoundError, pa.ArrowInvalid):
+                    # Create synthetic IDs if datum_id file not found
+                    size = len(tb)
+                    position_indices = [f"pos_{i}" for i in range(size)]
+                    tile_tb = pa.Table.from_arrays([pa.array(position_indices)], names=[id_field_name])
+            else:
+                # Create synthetic IDs when no unique_id_field is available
+                size = len(tb)
+                position_indices = [f"pos_{i}" for i in range(size)]
+                tile_tb = pa.Table.from_arrays([pa.array(position_indices)], names=[id_field_name])
 
             if "all_set" in tb.column_names:
                 if tb["all_set"][0].as_py() == True:
-                    datum_ids.extend(tile_tb[
+                    datum_ids.extend(tile_tb[id_field_name].to_pylist())
                 else:
                     # filter on rows
                     try:
-                        tb = tb.append_column(
-                        datum_ids.extend(tb.filter(pc.field("bitmask") == True)[
+                        tb = tb.append_column(id_field_name, tile_tb[id_field_name])
+                        datum_ids.extend(tb.filter(pc.field("bitmask") == True)[id_field_name].to_pylist())
                     except Exception as e:
                         raise Exception(f"Failed to fetch datums in tag. {e}")
         return datum_ids
```
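
For per-tile tag membership, the new code attaches the ID column to the tag bitmask table and filters on the bitmask. A compact sketch of that filter with made-up rows (it mirrors the calls in the hunk above and assumes a reasonably recent pyarrow that accepts expressions in `Table.filter`):

```python
# Sketch of the bitmask filter used when collecting tagged datum ids; data is illustrative.
import pyarrow as pa
import pyarrow.compute as pc

tb = pa.table({"bitmask": [True, False, True]})
tile_tb = pa.table({"id": ["pos_0", "pos_1", "pos_2"]})

tb = tb.append_column("id", tile_tb["id"])
tagged_ids = tb.filter(pc.field("bitmask") == True)["id"].to_pylist()
# -> ["pos_0", "pos_2"]
```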

```diff
@@ -849,7 +1044,6 @@ class AtlasMapData:
     def __init__(self, projection: "AtlasProjection", fields=None):  # type: ignore
         self.projection = projection
         self.dataset = projection.dataset
-        self.id_field = self.projection.dataset.id_field
         if fields is None:
             # TODO: fall back on something more reliable here
             self.fields = self.dataset.dataset_fields
@@ -858,6 +1052,9 @@ class AtlasMapData:
                 assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
             self.fields = fields
         self._tb = None
+        self._has_unique_id_field = (
+            "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
+        )
 
     def _load_data(self, data_columns: List[Tuple[str, str]]):
         """
@@ -871,24 +1068,67 @@ class AtlasMapData:
         sidecars_to_load = set([sidecar for _, sidecar in data_columns if sidecar != "datum_id"])
         logger.info("Loading data")
         for key in tqdm(self.projection._manifest["key"].to_pylist()):
-            # Use datum id as root table
-            …
+            # Use datum id as root table if available, otherwise create synthetic IDs
+            if self._has_unique_id_field:
+                try:
+                    datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+                    tb = feather.read_table(datum_id_path, memory_map=True)
+                except (FileNotFoundError, pa.ArrowInvalid):
+                    # Create a synthetic ID table
+                    # Using the first sidecar to determine table size
+                    first_sidecar = next(iter(sidecars_to_load)) if sidecars_to_load else ""
+                    tb = self._create_synthetic_id_table(key, first_sidecar)
+            else:
+                # Create synthetic IDs when no unique_id_field is available
+                first_sidecar = next(iter(sidecars_to_load)) if sidecars_to_load else ""
+                tb = self._create_synthetic_id_table(key, first_sidecar)
 
             for sidecar in sidecars_to_load:
                 path = self.projection.tile_destination
                 if sidecar == "":
                     path = path / Path(key).with_suffix(".feather")
                 else:
                     path = path / Path(key).with_suffix(f".{sidecar}.feather")
-                …
+
+                try:
+                    carfile = feather.read_table(path, memory_map=True)
+                    for col in carfile.column_names:
+                        if col in self.fields:
+                            tb = tb.append_column(col, carfile[col])
+                except (FileNotFoundError, pa.ArrowInvalid) as e:
+                    logger.warning(f"Error loading data for key {key}, sidecar {sidecar}: {e}")
+                    continue
+
             tbs.append(tb)
 
+        if not tbs:
+            raise ValueError("No data could be loaded. Data files may be missing or corrupt.")
+
         self._tb = pa.concat_tables(tbs)
 
+    def _create_synthetic_id_table(self, key, sidecar):
+        """
+        Create a synthetic table with position indices when datum_id file isn't available
+        or when unique_id_field isn't specified.
+        """
+        # Try to determine the size of the table by loading a sidecar file
+        path = self.projection.tile_destination
+        if sidecar == "":
+            path = path / Path(key).with_suffix(".feather")
+        else:
+            path = path / Path(key).with_suffix(f".{sidecar}.feather")
+
+        try:
+            sidecar_tb = feather.read_table(path, memory_map=True)
+            size = len(sidecar_tb)
+            # Create a table with position indices as IDs
+            position_indices = [f"pos_{i}" for i in range(size)]
+            return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
+        except Exception as e:
+            logger.error(f"Failed to create synthetic IDs for {key}: {e}")
+            # Return an empty table as fallback
+            return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
+
     def _download_data(self, fields: Optional[List[str]] = None) -> List[Tuple[str, str]]:
         """
         Downloads the feather tree for user uploaded data.
@@ -902,16 +1142,26 @@ class AtlasMapData:
         logger.info("Downloading data")
         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
 
-        # Download specified or all sidecar fields + always download datum_id
+        # Download specified or all sidecar fields + always download datum_id if available
         data_columns_to_load = [
             (str(field), str(sidecar))
             for field, sidecar in self.projection._registered_columns
             if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
         ]
 
-        …
+        # Only download datum_id if we have a unique ID field
+        if self._has_unique_id_field:
+            try:
+                self.projection._download_sidecar("datum_id")
+            except ValueError as e:
+                logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
+                self._has_unique_id_field = False
+
+        # Download all required sidecars for the fields
+        sidecars_to_download = set(sidecar for _, sidecar in data_columns_to_load if sidecar != "datum_id")
+        for sidecar in sidecars_to_download:
             self.projection._download_sidecar(sidecar)
+
         return data_columns_to_load
 
     @property
```

nomic/dataset.py (+135 -123):

```diff
@@ -115,15 +115,12 @@ class AtlasClass(object):
 
         return response.json()
 
-    def _validate_map_data_inputs(self, colorable_fields,
+    def _validate_map_data_inputs(self, colorable_fields, data_sample):
         """Validates inputs to map data calls."""
 
         if not isinstance(colorable_fields, list):
             raise ValueError("colorable_fields must be a list of fields")
 
-        if id_field in colorable_fields:
-            raise Exception(f"Cannot color by unique id field: {id_field}")
-
         for field in colorable_fields:
             if field not in data_sample:
                 raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.")
@@ -274,14 +271,12 @@ class AtlasClass(object):
         """
         Private method. validates upload data against the dataset arrow schema, and associated other checks.
 
-        1. If unique_id_field is specified, validates that each datum has that field. If not, adds it and then notifies the user that it was added.
-
         Args:
             data: an arrow table.
             project: the atlas dataset you are validating the data for.
 
         Returns:
-            …
+            Validated pyarrow table.
         """
         if not isinstance(data, pa.Table):
             raise Exception("Invalid data type for upload: {}".format(type(data)))
```

```diff
@@ -295,8 +290,32 @@ class AtlasClass(object):
                 msg = "Must include embeddings in embedding dataset upload."
                 raise ValueError(msg)
 
-        …
+        # Check and validate ID field if specified
+        if "unique_id_field" in project.meta and project.meta["unique_id_field"] is not None:
+            id_field = project.meta["unique_id_field"]
+
+            # Check if ID field exists in data
+            if id_field not in data.column_names:
+                raise ValueError(
+                    f"Data must contain the ID column `{id_field}` as specified in dataset's unique_id_field"
+                )
+
+            # Check for null values in ID field
+            if data[id_field].null_count > 0:
+                raise ValueError(
+                    f"As your unique id field, {id_field} must not contain null values, but {data[id_field].null_count} found."
+                )
+
+            # Check ID field length (36 characters max)
+            if pa.types.is_string(data[id_field].type):
+                # Use a safer alternative to check string length
+                utf8_length_values = pc.utf8_length(data[id_field])  # type: ignore
+                max_length_scalar = pc.max(utf8_length_values)  # type: ignore
+                max_length = max_length_scalar.as_py()
+                if max_length > 36:
+                    raise ValueError(
+                        f"The id_field contains values greater than 36 characters. Atlas does not support id_fields longer than 36 characters."
+                    )
 
         seen = set()
         for col in data.column_names:
```
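
The upload-time validation now checks the declared ID column for nulls and enforces the 36-character cap with `pyarrow.compute`. A standalone sketch of the same two checks on a toy column:

```python
# Standalone sketch of the ID checks added to the upload validation path.
import pyarrow as pa
import pyarrow.compute as pc

ids = pa.array(["a" * 10, "b" * 20])

assert ids.null_count == 0, "unique id field must not contain null values"
max_length = pc.max(pc.utf8_length(ids)).as_py()
assert max_length <= 36, "Atlas does not support id_fields longer than 36 characters"
```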

```diff
@@ -313,11 +332,6 @@ class AtlasClass(object):
         # filling in nulls, etc.
         reformatted = {}
 
-        if data[project.id_field].null_count > 0:
-            raise ValueError(
-                f"{project.id_field} must not contain null values, but {data[project.id_field].null_count} found."
-            )
-
         assert project.schema is not None, "Project schema not found."
 
         for field in project.schema:
@@ -335,10 +349,15 @@ class AtlasClass(object):
                         f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version."
                     )
                     reformatted[field.name] = pc.fill_null(reformatted[field.name], "null")
-                …
+
+                # Check for empty strings and replace with "null"
+                # Separate the operations for better type checking
+                binary_length_values = pc.binary_length(reformatted[field.name])  # type: ignore
+                has_empty_strings = pc.equal(binary_length_values, 0)  # type: ignore
+                if pc.any(has_empty_strings).as_py():  # type: ignore
+                    mask = has_empty_strings.combine_chunks()
+                    assert pa.types.is_boolean(mask.type)
+                    reformatted[field.name] = pc.replace_with_mask(reformatted[field.name], mask, "null")  # type: ignore
         for field in data.schema:
             if not field.name in reformatted:
                 if field.name == "_embeddings":
```
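
`pc.replace_with_mask` substitutes values wherever the boolean mask is true, which is how empty strings become the literal "null" in the hunk above. A small sketch on a plain (non-chunked) array, mirroring the diff's call:

```python
# Sketch of the empty-string replacement; values are illustrative.
import pyarrow as pa
import pyarrow.compute as pc

col = pa.array(["ok", "", "also ok"])
mask = pc.equal(pc.binary_length(col), 0)
col = pc.replace_with_mask(col, mask, "null")
# col is now ["ok", "null", "also ok"]
```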

```diff
@@ -350,27 +369,12 @@ class AtlasClass(object):
         if project.meta["insert_update_delete_lock"]:
             raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.")
 
-        # The following two conditions should never occur given the above, but just in case...
-        assert project.id_field in data.column_names, f"Upload does not contain your specified id_field"
-
-        if not pa.types.is_string(data[project.id_field].type):
-            logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}")
-            data = data.drop([project.id_field]).append_column(
-                project.id_field, data[project.id_field].cast(pa.string())
-            )
-
         for key in data.column_names:
             if key.startswith("_"):
                 if key == "_embeddings" or key == "_blob_hash":
                     continue
                 raise ValueError("Metadata fields cannot start with _")
-
-        first_match = data.filter(
-            pa.compute.greater(pa.compute.utf8_length(data[project.id_field]), 36)  # type: ignore
-        ).to_pylist()[0][project.id_field]
-        raise ValueError(
-            f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters."
-        )
+
         return data
 
     def _get_organization(self, organization_slug=None, organization_id=None) -> Tuple[str, str]:
@@ -696,10 +700,6 @@ class AtlasProjection:
     def tile_destination(self):
         return Path("~/.nomic/cache", self.id).expanduser()
 
-    @property
-    def datum_id_field(self):
-        return self.dataset.meta["unique_id_field"]
-
     def _get_atoms(self, ids: List[str]) -> List[Dict]:
         """
         Retrieves atoms by id
@@ -766,7 +766,7 @@ class AtlasDataset(AtlasClass):
 
         * **identifier** - The dataset identifier in the form `dataset` or `organization/dataset`. If no organization is passed, the organization tied to the API key you logged in to Nomic with will be used.
         * **description** - A description for the dataset.
-        * **unique_id_field** -
+        * **unique_id_field** - A field that uniquely identifies each data point.
         * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
         * **dataset_id** - An alternative way to load a dataset is by passing the dataset_id directly. This only works if a dataset exists.
         """
@@ -805,13 +805,6 @@ class AtlasDataset(AtlasClass):
             dataset_id = dataset["id"]
 
         if dataset_id is None:  # if there is no existing project, make a new one.
-            if unique_id_field is None:  # if not all parameters are specified, we weren't trying to make a project
-                raise ValueError(f"Dataset `{identifier}` does not exist.")
-
-            # if modality is None:
-            # raise ValueError("You must specify a modality when creating a new dataset.")
-            #
-            # assert modality in ['text', 'embedding'], "Modality must be either `text` or `embedding`"
             assert identifier is not None
 
             dataset_id = self._create_project(
@@ -840,7 +833,7 @@ class AtlasDataset(AtlasClass):
         self,
         identifier: str,
         description: Optional[str],
-        unique_id_field: str,
+        unique_id_field: Optional[str] = None,
         is_public: bool = True,
     ):
         """
@@ -852,7 +845,7 @@ class AtlasDataset(AtlasClass):
 
         * **identifier** - The identifier for the dataset.
         * **description** - A description for the dataset.
-        * **unique_id_field** -
+        * **unique_id_field** - A field that uniquely identifies each data point.
         * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
 
         **Returns:** project_id on success.
@@ -865,15 +858,6 @@ class AtlasDataset(AtlasClass):
         if "/" in identifier:
             org_name = identifier.split("/")[0]
             logger.info(f"Organization name: `{org_name}`")
-        # supported_modalities = ['text', 'embedding']
-        # if modality not in supported_modalities:
-        # msg = 'Tried to create dataset with modality: {}, but Atlas only supports: {}'.format(
-        # modality, supported_modalities
-        # )
-        # raise ValueError(msg)
-
-        if unique_id_field is None:
-            raise ValueError("You must specify a unique id field")
         if description is None:
             description = ""
         response = requests.post(
@@ -884,7 +868,6 @@ class AtlasDataset(AtlasClass):
                 "project_name": project_slug,
                 "description": description,
                 "unique_id_field": unique_id_field,
-                # 'modality': modality,
                 "is_public": is_public,
             },
         )
@@ -939,11 +922,12 @@ class AtlasDataset(AtlasClass):
 
     @property
     def id(self) -> str:
-        """The
+        """The ID of the dataset."""
         return self.meta["id"]
 
     @property
     def id_field(self) -> str:
+        """The unique_id_field of the dataset."""
         return self.meta["unique_id_field"]
 
     @property
@@ -1147,7 +1131,7 @@ class AtlasDataset(AtlasClass):
         colorable_fields = []
 
         for field in self.dataset_fields:
-            if field not in [
+            if field not in [indexed_field] and not field.startswith("_"):
                 colorable_fields.append(field)
 
         build_template = {}
```

```diff
@@ -1410,98 +1394,126 @@ class AtlasDataset(AtlasClass):
         Uploads blobs to the server and associates them with the data.
         Blobs must reference objects stored locally
         """
+        data_as_table: pa.Table
         if isinstance(data, DataFrame):
-            …
+            data_as_table = pa.Table.from_pandas(data)
         elif isinstance(data, list):
-            …
-        elif
+            data_as_table = pa.Table.from_pylist(data)
+        elif isinstance(data, pa.Table):
+            data_as_table = data
+        else:
             raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.")
 
-        …
-        # add hash to data as _blob_hash
-        # set indexed_field to _blob_hash
-        # call _add_data
-
-        # Cast self id field to string for merged data lower down on function
-        data = data.set_column(  # type: ignore
-            data.schema.get_field_index(self.id_field), self.id_field, pc.cast(data[self.id_field], pa.string())  # type: ignore
-        )
+        # Compute dataset length
+        data_length = len(data_as_table)
+        if data_length != len(blobs):
+            raise ValueError(f"Number of data points ({data_length}) must match number of blobs ({len(blobs)})")
 
-        …
-        images = []
-        for
-        …
+        TEMP_ID_COLUMN = "_nomic_internal_temp_id"
+        temp_id_values = [str(i) for i in range(data_length)]
+        data_as_table = data_as_table.append_column(TEMP_ID_COLUMN, pa.array(temp_id_values, type=pa.string()))
+        blob_upload_endpoint = "/v1/project/data/add/blobs"
+        actual_temp_ids = data_as_table[TEMP_ID_COLUMN].to_pylist()
+        images = []  # List of (temp_id, image_bytes)
+        for i in tqdm(range(data_length), desc="Processing images"):
+            current_temp_id = actual_temp_ids[i]
+            blob_item = blobs[i]
+
+            processed_blob_value = None
+            if (isinstance(blob_item, str) or isinstance(blob_item, Path)) and os.path.exists(blob_item):
+                image = Image.open(blob_item).convert("RGB")
                 if image.height > 512 or image.width > 512:
                     image = image.resize((512, 512))
                 buffered = BytesIO()
                 image.save(buffered, format="JPEG")
-                …
-            elif isinstance(
-                …
-            elif isinstance(
-                …
-                if
-                …
+                processed_blob_value = buffered.getvalue()
+            elif isinstance(blob_item, bytes):
+                processed_blob_value = blob_item
+            elif isinstance(blob_item, Image.Image):
+                img_pil = blob_item.convert("RGB")  # Ensure it's PIL Image for methods
+                if img_pil.height > 512 or img_pil.width > 512:
+                    img_pil = img_pil.resize((512, 512))
                 buffered = BytesIO()
-                …
+                img_pil.save(buffered, format="JPEG")
+                processed_blob_value = buffered.getvalue()
             else:
-                raise ValueError(
+                raise ValueError(
+                    f"Invalid blob type for item at index {i} (temp_id: {current_temp_id}). Must be a path, bytes, or PIL Image. Got: {type(blob_item)}"
+                )
+
+            if processed_blob_value is not None:
+                images.append((current_temp_id, processed_blob_value))
 
         batch_size = 40
         num_workers = 10
 
-        def send_request(
-            image_batch = images[
-            …
+        def send_request(batch_start_index):
+            image_batch = images[batch_start_index : batch_start_index + batch_size]
+            temp_ids_in_batch = [item_id for item_id, _ in image_batch]
+            blobs_for_api = [("blobs", blob_val) for _, blob_val in image_batch]
+
             response = requests.post(
                 self.atlas_api_path + blob_upload_endpoint,
                 headers=self.header,
-                data={"dataset_id": self.id},
-                files=
+                data={"dataset_id": self.id},  # self.id is project_id
+                files=blobs_for_api,
             )
             if response.status_code != 200:
-                …
+                failed_ids_sample = temp_ids_in_batch[:5]
+                logger.error(
+                    f"Blob upload request failed for batch starting with temp_ids: {failed_ids_sample}. Status: {response.status_code}, Response: {response.text}"
+                )
+                raise Exception(f"Blob upload failed: {response.text}")
+            return {temp_id: blob_hash for temp_id, blob_hash in zip(temp_ids_in_batch, response.json()["hashes"])}
 
-        …
+        upload_pbar = pbar  # Use passed-in pbar if available
+        close_upload_pbar_locally = False
+        if upload_pbar is None:
+            upload_pbar = tqdm(total=len(images), desc="Uploading blobs to Atlas")
+            close_upload_pbar_locally = True
 
-        …
-        returned_ids = []
+        returned_temp_ids = []
         returned_hashes = []
+        succeeded_uploads = 0
 
-        succeeded = 0
         with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
-            futures = {executor.submit(send_request, i): i for i in range(0, len(
+            futures = {executor.submit(send_request, i): i for i in range(0, len(images), batch_size)}
 
             for future in concurrent.futures.as_completed(futures):
-                …
+                try:
+                    response_dict = future.result()  # This is {temp_id: blob_hash}
+                    for temp_id, blob_hash_val in response_dict.items():
+                        returned_temp_ids.append(temp_id)
+                        returned_hashes.append(blob_hash_val)
+                    succeeded_uploads += len(response_dict)
+                    if upload_pbar:
+                        upload_pbar.update(len(response_dict))
+                except Exception as e:
+                    logger.error(f"An error occurred during blob upload processing for a batch: {e}")
+                    # Optionally, collect failed batch info here if needed for partial success
+
+        if close_upload_pbar_locally and upload_pbar:
+            upload_pbar.close()
+
+        hash_schema = pa.schema([(TEMP_ID_COLUMN, pa.string()), ("_blob_hash", pa.string())])
+
+        merged_data_as_table: pa.Table
+        if succeeded_uploads > 0:  # Only create hash_tb if there are successful uploads
+            hash_tb = pa.Table.from_pydict(
+                {TEMP_ID_COLUMN: returned_temp_ids, "_blob_hash": returned_hashes}, schema=hash_schema
+            )
+            merged_data_as_table = data_as_table.join(right_table=hash_tb, keys=TEMP_ID_COLUMN, join_type="left outer")
+        else:  # No successful uploads, so no hashes to merge, but keep original data structure
+            # Need to ensure _blob_hash column is added with nulls, and id_field is present
+            if "_blob_hash" not in data_as_table.column_names:
+                data_as_table = data_as_table.append_column(
+                    "_blob_hash", pa.nulls(data_as_table.num_rows, type=pa.string())
+                )
+            merged_data_as_table = data_as_table
 
-        …
-        merged_data = data.join(right_table=hash_tb, keys=self.id_field)  # type: ignore
+        merged_data_as_table = merged_data_as_table.drop_columns([TEMP_ID_COLUMN])
 
-        self._add_data(
+        self._add_data(merged_data_as_table, pbar=pbar)  # Pass original pbar argument
 
     def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None):
         """
```