nomic 3.5.0__tar.gz → 3.5.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nomic
3
- Version: 3.5.0
3
+ Version: 3.5.2
4
4
  Summary: The official Nomic python client.
5
5
  Home-page: https://github.com/nomic-ai/nomic
6
6
  Author: nomic.ai
@@ -3,7 +3,6 @@ This class allows for programmatic interactions with Atlas - Nomic's neural data
3
3
  or in a Jupyter Notebook to organize and interact with your unstructured data.
4
4
  """
5
5
 
6
- import uuid
7
6
  from typing import Dict, Iterable, List, Optional, Union
8
7
 
9
8
  import numpy as np
@@ -42,7 +41,7 @@ def map_data(
42
41
  embeddings: An [N,d] numpy array containing the N embeddings to add.
43
42
  identifier: A name for your dataset that is used to generate the dataset identifier. A unique name will be chosen if not supplied.
44
43
  description: The description of your dataset
45
- id_field: Specify your data unique id field. This field can be up 36 characters in length. If not specified, one will be created for you named `id_`.
44
+ id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
46
45
  is_public: Should the dataset be accessible outside your Nomic Atlas organization.
47
46
  indexed_field: The text field from the dataset that will be used to create embeddings, which determines the layout of the data map in Atlas. Required for text data but won't have an impact if uploading embeddings or image blobs.
48
47
  projection: Options for configuring the 2D projection algorithm.
@@ -86,9 +85,6 @@ def map_data(
86
85
  # default to vision v1.5
87
86
  embedding_model = NomicEmbedOptions(model="nomic-embed-vision-v1.5")
88
87
 
89
- if id_field is None:
90
- id_field = ATLAS_DEFAULT_ID_FIELD
91
-
92
88
  project_name = get_random_name()
93
89
 
94
90
  dataset_name = project_name
@@ -100,38 +96,6 @@ def map_data(
100
96
  if description:
101
97
  description = description
102
98
 
103
- # no metadata was specified
104
- added_id_field = False
105
-
106
- if data is None:
107
- added_id_field = True
108
- if embeddings is not None:
109
- data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(embeddings))]
110
- elif blobs is not None:
111
- data = [{ATLAS_DEFAULT_ID_FIELD: b64int(i)} for i in range(len(blobs))]
112
- else:
113
- raise ValueError("You must specify either data, embeddings, or blobs")
114
-
115
- if id_field == ATLAS_DEFAULT_ID_FIELD and data is not None:
116
- if isinstance(data, list) and id_field not in data[0]:
117
- added_id_field = True
118
- for i in range(len(data)):
119
- # do not modify object the user passed in - also ensures IDs are unique if two input datums are the same *object*
120
- data[i] = data[i].copy()
121
- data[i][id_field] = b64int(i)
122
- elif isinstance(data, DataFrame) and id_field not in data.columns:
123
- data[id_field] = [b64int(i) for i in range(data.shape[0])]
124
- added_id_field = True
125
- elif isinstance(data, pa.Table) and not id_field in data.column_names: # type: ignore
126
- ids = pa.array([b64int(i) for i in range(len(data))])
127
- data = data.append_column(id_field, ids) # type: ignore
128
- added_id_field = True
129
- elif id_field not in data[0]:
130
- raise ValueError("map_data data must be a list of dicts, a pandas dataframe, or a pyarrow table")
131
-
132
- if added_id_field:
133
- logger.warning("An ID field was not specified in your data so one was generated for you in insertion order.")
134
-
135
99
  dataset = AtlasDataset(
136
100
  identifier=dataset_name, description=description, unique_id_field=id_field, is_public=is_public
137
101
  )
@@ -202,7 +166,7 @@ def map_embeddings(
202
166
  Args:
203
167
  embeddings: An [N,d] numpy array containing the batch of N embeddings to add.
204
168
  data: An [N,] element list of dictionaries containing metadata for each embedding.
205
- id_field: Specify your data unique id field. This field can be up 36 characters in length. If not specified, one will be created for you named `id_`.
169
+ id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
206
170
  name: A name for your dataset. Specify in the format `organization/project` to create in a specific organization.
207
171
  description: A description for your map.
208
172
  is_public: Should this embedding map be public? Private maps can only be accessed by members of your organization.
@@ -250,7 +214,7 @@ def map_text(
250
214
  Args:
251
215
  data: An [N,] element iterable of dictionaries containing metadata for each embedding.
252
216
  indexed_field: The name the data field containing the text your want to map.
253
- id_field: Specify your data unique id field. This field can be up 36 characters in length. If not specified, one will be created for you named `id_`.
217
+ id_field: Specify a field that uniquely identifies each datapoint. This field can be up 36 characters in length.
254
218
  name: A name for your dataset. Specify in the format `organization/project` to create in a specific organization.
255
219
  description: A description for your map.
256
220
  build_topic_model: Builds a hierarchical topic model over your data to discover patterns.
@@ -25,7 +25,6 @@ class AtlasMapDuplicates:
25
25
 
26
26
  def __init__(self, projection: "AtlasProjection"): # type: ignore
27
27
  self.projection = projection
28
- self.id_field = self.projection.dataset.id_field
29
28
 
30
29
  duplicate_columns = [
31
30
  (field, sidecar)
@@ -40,7 +39,13 @@ class AtlasMapDuplicates:
40
39
 
41
40
  self._duplicate_column = duplicate_columns[0]
42
41
  self._cluster_column = cluster_columns[0]
42
+ self.duplicate_field = None
43
+ self.cluster_field = None
43
44
  self._tb = None
45
+ self._has_unique_id_field = (
46
+ "unique_id_field" in self.projection.dataset.meta
47
+ and self.projection.dataset.meta["unique_id_field"] is not None
48
+ )
44
49
 
45
50
  def _load_duplicates(self):
46
51
  """
@@ -51,30 +56,80 @@ class AtlasMapDuplicates:
51
56
  self.duplicate_field = self._duplicate_column[0].lstrip("_")
52
57
  self.cluster_field = self._cluster_column[0].lstrip("_")
53
58
  logger.info("Loading duplicates")
59
+ id_field_name = "_position_index"
60
+ if self._has_unique_id_field:
61
+ id_field_name = self.projection.dataset.meta["unique_id_field"]
62
+
54
63
  for key in tqdm(self.projection._manifest["key"].to_pylist()):
55
- # Use datum id as root table
56
- tb = feather.read_table(
57
- self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
58
- )
59
- path = self.projection.tile_destination
64
+ # Use datum id as root table if available, otherwise create synthetic IDs
65
+ if self._has_unique_id_field:
66
+ try:
67
+ datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
68
+ tb = feather.read_table(datum_id_path, memory_map=True)
69
+ except (FileNotFoundError, pa.ArrowInvalid):
70
+ # Create a synthetic ID table
71
+ tb = self._create_synthetic_id_table(key, duplicate_sidecar)
72
+ else:
73
+ # Create synthetic IDs when no unique_id_field is available
74
+ tb = self._create_synthetic_id_table(key, duplicate_sidecar)
60
75
 
76
+ path = self.projection.tile_destination
61
77
  if duplicate_sidecar == "":
62
78
  path = path / Path(key).with_suffix(".feather")
63
79
  else:
64
80
  path = path / Path(key).with_suffix(f".{duplicate_sidecar}.feather")
65
81
 
66
- duplicate_tb = feather.read_table(path, memory_map=True)
67
- for field in (self._duplicate_column[0], self._cluster_column[0]):
68
- tb = tb.append_column(field, duplicate_tb[field])
69
- tbs.append(tb)
70
- self._tb = pa.concat_tables(tbs).rename_columns([self.id_field, self.duplicate_field, self.cluster_field])
82
+ try:
83
+ duplicate_tb = feather.read_table(path, memory_map=True)
84
+ for field in (self._duplicate_column[0], self._cluster_column[0]):
85
+ tb = tb.append_column(field, duplicate_tb[field])
86
+ tbs.append(tb)
87
+ except (FileNotFoundError, pa.ArrowInvalid) as e:
88
+ logger.warning(f"Error loading duplicate data for key {key}: {e}")
89
+ continue
90
+
91
+ if not tbs:
92
+ raise ValueError("No duplicate data could be loaded. Duplicate data files may be missing or corrupt.")
93
+
94
+ self._tb = pa.concat_tables(tbs).rename_columns([id_field_name, self.duplicate_field, self.cluster_field])
95
+
96
+ def _create_synthetic_id_table(self, key, sidecar):
97
+ """
98
+ Create a synthetic table with position indices when datum_id file isn't available
99
+ or when unique_id_field isn't specified.
100
+ """
101
+ # Try to determine the size of the table by loading the duplicate sidecar
102
+ path = self.projection.tile_destination
103
+ if sidecar == "":
104
+ path = path / Path(key).with_suffix(".feather")
105
+ else:
106
+ path = path / Path(key).with_suffix(f".{sidecar}.feather")
107
+
108
+ try:
109
+ sidecar_tb = feather.read_table(path, memory_map=True)
110
+ size = len(sidecar_tb)
111
+ # Create a table with position indices as IDs
112
+ position_indices = [f"pos_{i}" for i in range(size)]
113
+ return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
114
+ except Exception as e:
115
+ logger.error(f"Failed to create synthetic IDs for {key}: {e}")
116
+ # Return an empty table as fallback
117
+ return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
71
118
 
72
119
  def _download_duplicates(self):
73
120
  """
74
121
  Downloads the feather tree for duplicates.
75
122
  """
76
123
  logger.info("Downloading duplicates")
77
- self.projection._download_sidecar("datum_id", overwrite=False)
124
+
125
+ # Only download datum_id if we have a unique ID field
126
+ if self._has_unique_id_field:
127
+ try:
128
+ self.projection._download_sidecar("datum_id", overwrite=False)
129
+ except ValueError as e:
130
+ logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
131
+ self._has_unique_id_field = False
132
+
78
133
  assert self._cluster_column[1] == self._duplicate_column[1], "Cluster and duplicate should be in same sidecar"
79
134
  self.projection._download_sidecar(self._duplicate_column[1], overwrite=False)
80
135
 
@@ -100,17 +155,24 @@ class AtlasMapDuplicates:
100
155
 
101
156
  def deletion_candidates(self) -> List[str]:
102
157
  """
103
-
104
158
  Returns:
105
159
  The ids for all data points which are semantic duplicates and are candidates for being deleted from the dataset. If you remove these data points from your dataset, your dataset will be semantically deduplicated.
106
160
  """
107
- dupes = self.tb[self.id_field].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate")) # type: ignore
161
+ id_field_name = "_position_index"
162
+ if self._has_unique_id_field:
163
+ id_field_name = self.projection.dataset.meta["unique_id_field"]
164
+
165
+ dupes = self.tb[id_field_name].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate")) # type: ignore
108
166
  return dupes.to_pylist()
109
167
 
110
168
  def __repr__(self) -> str:
169
+ id_field_name = "_position_index"
170
+ if self._has_unique_id_field:
171
+ id_field_name = self.projection.dataset.meta["unique_id_field"]
172
+
111
173
  repr = f"===Atlas Duplicates for ({self.projection})\n"
112
174
  duplicate_count = len(
113
- self.tb[self.id_field].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate")) # type: ignore
175
+ self.tb[id_field_name].filter(pa.compute.equal(self.tb[self.duplicate_field], "deletion candidate")) # type: ignore
114
176
  )
115
177
  cluster_count = len(self.tb[self.cluster_field].value_counts())
116
178
  repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n"
@@ -125,7 +187,6 @@ class AtlasMapTopics:
125
187
  def __init__(self, projection: "AtlasProjection"): # type: ignore
126
188
  self.projection = projection
127
189
  self.dataset = projection.dataset
128
- self.id_field = self.projection.dataset.id_field
129
190
  self._metadata = None
130
191
  self._hierarchy = None
131
192
  self._topic_columns = [
@@ -134,6 +195,9 @@ class AtlasMapTopics:
134
195
  assert len(self._topic_columns) > 0, "Topic modeling has not yet been run on this map."
135
196
  self.depth = len(self._topic_columns)
136
197
  self._tb = None
198
+ self._has_unique_id_field = (
199
+ "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
200
+ )
137
201
 
138
202
  def _load_topics(self):
139
203
  """
@@ -149,45 +213,96 @@ class AtlasMapTopics:
149
213
  # Should just be one sidecar
150
214
  topic_sidecar = set([sidecar for _, sidecar in self._topic_columns]).pop()
151
215
  logger.info("Loading topics")
216
+ id_field_name = "_position_index"
217
+ if self._has_unique_id_field:
218
+ id_field_name = self.dataset.meta["unique_id_field"]
219
+
152
220
  for key in tqdm(self.projection._manifest["key"].to_pylist()):
153
- # Use datum id as root table
154
- tb = feather.read_table(
155
- self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
156
- )
221
+ # Use datum id as root table if available, otherwise create a synthetic index
222
+ if self._has_unique_id_field:
223
+ try:
224
+ datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
225
+ tb = feather.read_table(datum_id_path, memory_map=True)
226
+ except (FileNotFoundError, pa.ArrowInvalid):
227
+ # If datum_id file doesn't exist, create a table with a position index
228
+ logger.warning(f"Datum ID file not found for key {key}. Creating synthetic IDs.")
229
+ tb = self._create_synthetic_id_table(key, topic_sidecar)
230
+ else:
231
+ # Create a table with a position index when no unique_id_field is available
232
+ tb = self._create_synthetic_id_table(key, topic_sidecar)
233
+
157
234
  path = self.projection.tile_destination
158
235
  if topic_sidecar == "":
159
236
  path = path / Path(key).with_suffix(".feather")
160
237
  else:
161
238
  path = path / Path(key).with_suffix(f".{topic_sidecar}.feather")
162
239
 
163
- topic_tb = feather.read_table(path, memory_map=True)
164
- # Do this in depth order
165
- for d in range(1, self.depth + 1):
166
- column = f"_topic_depth_{d}"
167
- if integer_topics:
168
- column = f"_topic_depth_{d}_int"
169
- topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
170
- assert label_df is not None
171
- topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
172
- topic_ids_to_label, on="topic_id", how="right"
173
- )
174
- new_column = f"_topic_depth_{d}"
175
- tb = tb.append_column(
176
- new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
177
- )
178
- else:
179
- tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
180
- tbs.append(tb)
240
+ try:
241
+ topic_tb = feather.read_table(path, memory_map=True)
242
+ # Do this in depth order
243
+ for d in range(1, self.depth + 1):
244
+ column = f"_topic_depth_{d}"
245
+ if integer_topics:
246
+ column = f"_topic_depth_{d}_int"
247
+ topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
248
+ assert label_df is not None
249
+ topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
250
+ topic_ids_to_label, on="topic_id", how="right"
251
+ )
252
+ new_column = f"_topic_depth_{d}"
253
+ tb = tb.append_column(
254
+ new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
255
+ )
256
+ else:
257
+ tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
258
+ tbs.append(tb)
259
+ except (FileNotFoundError, pa.ArrowInvalid) as e:
260
+ logger.warning(f"Error loading topic data for key {key}: {e}")
261
+ continue
181
262
 
182
- renamed_columns = [self.id_field] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
263
+ if not tbs:
264
+ raise ValueError("No topic data could be loaded. Topic data files may be missing or corrupt.")
265
+
266
+ renamed_columns = [id_field_name] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
183
267
  self._tb = pa.concat_tables(tbs).rename_columns(renamed_columns)
184
268
 
269
+ def _create_synthetic_id_table(self, key, sidecar):
270
+ """
271
+ Create a synthetic table with position indices when datum_id file isn't available
272
+ or when unique_id_field isn't specified.
273
+ """
274
+ # Try to determine the size of the table by loading the topic sidecar
275
+ path = self.projection.tile_destination
276
+ if sidecar == "":
277
+ path = path / Path(key).with_suffix(".feather")
278
+ else:
279
+ path = path / Path(key).with_suffix(f".{sidecar}.feather")
280
+
281
+ try:
282
+ topic_tb = feather.read_table(path, memory_map=True)
283
+ size = len(topic_tb)
284
+ # Create a table with position indices as IDs
285
+ position_indices = [f"pos_{i}" for i in range(size)]
286
+ return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
287
+ except Exception as e:
288
+ logger.error(f"Failed to create synthetic IDs for {key}: {e}")
289
+ # Return an empty table as fallback
290
+ return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
291
+
185
292
  def _download_topics(self):
186
293
  """
187
294
  Downloads the feather tree for topics.
188
295
  """
189
296
  logger.info("Downloading topics")
190
- self.projection._download_sidecar("datum_id", overwrite=False)
297
+
298
+ # Only download datum_id if we have a unique ID field
299
+ if self._has_unique_id_field:
300
+ try:
301
+ self.projection._download_sidecar("datum_id", overwrite=False)
302
+ except ValueError as e:
303
+ logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
304
+ self._has_unique_id_field = False
305
+
191
306
  topic_sidecars = set([sidecar for _, sidecar in self._topic_columns])
192
307
  assert len(topic_sidecars) == 1, "Multiple topic sidecars found."
193
308
  self.projection._download_sidecar(topic_sidecars.pop(), overwrite=False)
@@ -286,7 +401,10 @@ class AtlasMapTopics:
286
401
  raise ValueError("Topic depth out of range.")
287
402
 
288
403
  # Unique datum id column to aggregate
289
- datum_id_col = self.dataset.meta["unique_id_field"]
404
+ datum_id_col = "_position_index"
405
+ if "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None:
406
+ datum_id_col = self.dataset.meta["unique_id_field"]
407
+
290
408
  df = self.df
291
409
 
292
410
  topic_datum_dict = df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict()
@@ -333,8 +451,12 @@ class AtlasMapTopics:
333
451
  A list of `{topic, count}` dictionaries, sorted from largest count to smallest count.
334
452
  """
335
453
  data = AtlasMapData(self.projection, fields=[time_field])
336
- time_data = data.tb.select([self.id_field, time_field])
337
- merged_tb = self.tb.join(time_data, self.id_field, join_type="inner").combine_chunks()
454
+ id_field_name = "_position_index"
455
+ if "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None:
456
+ id_field_name = self.dataset.meta["unique_id_field"]
457
+
458
+ time_data = data.tb.select([id_field_name, time_field])
459
+ merged_tb = self.tb.join(time_data, id_field_name, join_type="inner").combine_chunks()
338
460
 
339
461
  del time_data # free up memory
340
462
 
@@ -343,12 +465,12 @@ class AtlasMapTopics:
343
465
  topic_densities = {}
344
466
  for depth in range(1, self.depth + 1):
345
467
  topic_column = f"topic_depth_{depth}"
346
- topic_counts = merged_tb.group_by(topic_column).aggregate([(self.id_field, "count")]).to_pandas()
468
+ topic_counts = merged_tb.group_by(topic_column).aggregate([(id_field_name, "count")]).to_pandas()
347
469
  for _, row in topic_counts.iterrows():
348
470
  topic = row[topic_column]
349
471
  if topic not in topic_densities:
350
472
  topic_densities[topic] = 0
351
- topic_densities[topic] += row[self.id_field + "_count"]
473
+ topic_densities[topic] += row[id_field_name + "_count"]
352
474
  return topic_densities
353
475
 
354
476
  def vector_search_topics(self, queries: np.ndarray, k: int = 32, depth: int = 3) -> Dict:
@@ -448,10 +570,12 @@ class AtlasMapEmbeddings:
448
570
 
449
571
  def __init__(self, projection: "AtlasProjection"): # type: ignore
450
572
  self.projection = projection
451
- self.id_field = self.projection.dataset.id_field
452
573
  self.dataset = projection.dataset
453
574
  self._tb: pa.Table = None
454
575
  self._latent = None
576
+ self._has_unique_id_field = (
577
+ "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
578
+ )
455
579
 
456
580
  @property
457
581
  def df(self):
@@ -481,23 +605,81 @@ class AtlasMapEmbeddings:
481
605
  tbs = []
482
606
  coord_sidecar = self.projection._get_sidecar_from_field("x")
483
607
  for key in tqdm(self.projection._manifest["key"].to_pylist()):
484
- # Use datum id as root table
485
- tb = feather.read_table(
486
- self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
487
- )
608
+ # Use datum id as root table if available, otherwise create synthetic IDs
609
+ if self._has_unique_id_field:
610
+ try:
611
+ datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
612
+ tb = feather.read_table(datum_id_path, memory_map=True)
613
+ except (FileNotFoundError, pa.ArrowInvalid):
614
+ # Create a synthetic ID table
615
+ tb = self._create_synthetic_id_table(key, coord_sidecar)
616
+ else:
617
+ # Create synthetic IDs when no unique_id_field is available
618
+ tb = self._create_synthetic_id_table(key, coord_sidecar)
619
+
488
620
  path = self.projection.tile_destination
489
621
  if coord_sidecar == "":
490
622
  path = path / Path(key).with_suffix(".feather")
491
623
  else:
492
624
  path = path / Path(key).with_suffix(f".{coord_sidecar}.feather")
493
- carfile = feather.read_table(path, memory_map=True)
494
- for col in carfile.column_names:
495
- if col in ["x", "y"]:
496
- tb = tb.append_column(col, carfile[col])
497
- tbs.append(tb)
625
+
626
+ try:
627
+ carfile = feather.read_table(path, memory_map=True)
628
+ for col in carfile.column_names:
629
+ if col in ["x", "y"]:
630
+ tb = tb.append_column(col, carfile[col])
631
+ tbs.append(tb)
632
+ except (FileNotFoundError, pa.ArrowInvalid) as e:
633
+ logger.warning(f"Error loading embedding data for key {key}: {e}")
634
+ continue
635
+
636
+ if not tbs:
637
+ raise ValueError("No embedding data could be loaded. Embedding data files may be missing or corrupt.")
638
+
498
639
  self._tb = pa.concat_tables(tbs)
499
640
  return self._tb
500
641
 
642
+ def _create_synthetic_id_table(self, key, sidecar):
643
+ """
644
+ Create a synthetic table with position indices when datum_id file isn't available
645
+ or when unique_id_field isn't specified.
646
+ """
647
+ # Try to determine the size of the table by loading the coordinate sidecar
648
+ path = self.projection.tile_destination
649
+ if sidecar == "":
650
+ path = path / Path(key).with_suffix(".feather")
651
+ else:
652
+ path = path / Path(key).with_suffix(f".{sidecar}.feather")
653
+
654
+ try:
655
+ coord_tb = feather.read_table(path, memory_map=True)
656
+ size = len(coord_tb)
657
+ # Create a table with position indices as IDs
658
+ position_indices = [f"pos_{i}" for i in range(size)]
659
+ return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
660
+ except Exception as e:
661
+ logger.error(f"Failed to create synthetic IDs for {key}: {e}")
662
+ # Return an empty table as fallback
663
+ return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
664
+
665
+ def _download_projected(self) -> List[Path]:
666
+ """
667
+ Downloads the feather tree for projection coordinates.
668
+ """
669
+ logger.info("Downloading projected embeddings")
670
+ # Note that y coord should be in same sidecar
671
+ coord_sidecar = self.projection._get_sidecar_from_field("x")
672
+
673
+ # Only download datum_id if we have a unique ID field
674
+ if self._has_unique_id_field:
675
+ try:
676
+ self.projection._download_sidecar("datum_id", overwrite=False)
677
+ except ValueError as e:
678
+ logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
679
+ self._has_unique_id_field = False
680
+
681
+ return self.projection._download_sidecar(coord_sidecar, overwrite=False)
682
+
501
683
  @property
502
684
  def projected(self) -> pd.DataFrame:
503
685
  """
@@ -531,16 +713,6 @@ class AtlasMapEmbeddings:
531
713
  all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims)) # type: ignore
532
714
  return np.vstack(all_embeddings)
533
715
 
534
- def _download_projected(self) -> List[Path]:
535
- """
536
- Downloads the feather tree for projection coordinates.
537
- """
538
- logger.info("Downloading projected embeddings")
539
- # Note that y coord should be in same sidecar
540
- coord_sidecar = self.projection._get_sidecar_from_field("x")
541
- self.projection._download_sidecar("datum_id", overwrite=False)
542
- return self.projection._download_sidecar(coord_sidecar, overwrite=False)
543
-
544
716
  def _download_latent(self) -> List[Path]:
545
717
  """
546
718
  Downloads the feather tree for embeddings.
@@ -670,12 +842,17 @@ class AtlasMapTags:
670
842
  def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False): # type: ignore
671
843
  self.projection = projection
672
844
  self.dataset = projection.dataset
673
- self.id_field = self.projection.dataset.id_field
674
845
  # Pre-fetch datum ids first upon initialization
675
- try:
676
- self.projection._download_sidecar("datum_id")
677
- except Exception:
678
- raise ValueError("Failed to fetch datum ids which is required to load tags.")
846
+ self._has_unique_id_field = (
847
+ "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
848
+ )
849
+
850
+ if self._has_unique_id_field:
851
+ try:
852
+ self.projection._download_sidecar("datum_id")
853
+ except Exception as e:
854
+ logger.warning(f"Failed to fetch datum ids: {e}. Will use synthetic IDs.")
855
+ self._has_unique_id_field = False
679
856
  self.auto_cleanup = auto_cleanup
680
857
 
681
858
  @property
@@ -748,20 +925,38 @@ class AtlasMapTags:
748
925
  """
749
926
  tag_paths = self._download_tag(tag_name, overwrite=overwrite)
750
927
  datum_ids = []
928
+ id_field_name = "_position_index"
929
+ if self._has_unique_id_field:
930
+ id_field_name = self.dataset.meta["unique_id_field"]
931
+
751
932
  for path in tag_paths:
752
933
  tb = feather.read_table(path)
753
934
  last_coord = path.name.split(".")[0]
754
- tile_path = path.with_name(last_coord + ".datum_id.feather")
755
- tile_tb = feather.read_table(tile_path).select([self.id_field])
935
+
936
+ # Get ID information - either from datum_id file or create synthetic IDs
937
+ if self._has_unique_id_field:
938
+ try:
939
+ tile_path = path.with_name(last_coord + ".datum_id.feather")
940
+ tile_tb = feather.read_table(tile_path).select([id_field_name])
941
+ except (FileNotFoundError, pa.ArrowInvalid):
942
+ # Create synthetic IDs if datum_id file not found
943
+ size = len(tb)
944
+ position_indices = [f"pos_{i}" for i in range(size)]
945
+ tile_tb = pa.Table.from_arrays([pa.array(position_indices)], names=[id_field_name])
946
+ else:
947
+ # Create synthetic IDs when no unique_id_field is available
948
+ size = len(tb)
949
+ position_indices = [f"pos_{i}" for i in range(size)]
950
+ tile_tb = pa.Table.from_arrays([pa.array(position_indices)], names=[id_field_name])
756
951
 
757
952
  if "all_set" in tb.column_names:
758
953
  if tb["all_set"][0].as_py() == True:
759
- datum_ids.extend(tile_tb[self.id_field].to_pylist())
954
+ datum_ids.extend(tile_tb[id_field_name].to_pylist())
760
955
  else:
761
956
  # filter on rows
762
957
  try:
763
- tb = tb.append_column(self.id_field, tile_tb[self.id_field])
764
- datum_ids.extend(tb.filter(pc.field("bitmask") == True)[self.id_field].to_pylist())
958
+ tb = tb.append_column(id_field_name, tile_tb[id_field_name])
959
+ datum_ids.extend(tb.filter(pc.field("bitmask") == True)[id_field_name].to_pylist())
765
960
  except Exception as e:
766
961
  raise Exception(f"Failed to fetch datums in tag. {e}")
767
962
  return datum_ids
@@ -849,7 +1044,6 @@ class AtlasMapData:
849
1044
  def __init__(self, projection: "AtlasProjection", fields=None): # type: ignore
850
1045
  self.projection = projection
851
1046
  self.dataset = projection.dataset
852
- self.id_field = self.projection.dataset.id_field
853
1047
  if fields is None:
854
1048
  # TODO: fall back on something more reliable here
855
1049
  self.fields = self.dataset.dataset_fields
@@ -858,6 +1052,9 @@ class AtlasMapData:
858
1052
  assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
859
1053
  self.fields = fields
860
1054
  self._tb = None
1055
+ self._has_unique_id_field = (
1056
+ "unique_id_field" in self.dataset.meta and self.dataset.meta["unique_id_field"] is not None
1057
+ )
861
1058
 
862
1059
  def _load_data(self, data_columns: List[Tuple[str, str]]):
863
1060
  """
@@ -871,24 +1068,67 @@ class AtlasMapData:
871
1068
  sidecars_to_load = set([sidecar for _, sidecar in data_columns if sidecar != "datum_id"])
872
1069
  logger.info("Loading data")
873
1070
  for key in tqdm(self.projection._manifest["key"].to_pylist()):
874
- # Use datum id as root table
875
- tb = feather.read_table(
876
- self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
877
- )
1071
+ # Use datum id as root table if available, otherwise create synthetic IDs
1072
+ if self._has_unique_id_field:
1073
+ try:
1074
+ datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
1075
+ tb = feather.read_table(datum_id_path, memory_map=True)
1076
+ except (FileNotFoundError, pa.ArrowInvalid):
1077
+ # Create a synthetic ID table
1078
+ # Using the first sidecar to determine table size
1079
+ first_sidecar = next(iter(sidecars_to_load)) if sidecars_to_load else ""
1080
+ tb = self._create_synthetic_id_table(key, first_sidecar)
1081
+ else:
1082
+ # Create synthetic IDs when no unique_id_field is available
1083
+ first_sidecar = next(iter(sidecars_to_load)) if sidecars_to_load else ""
1084
+ tb = self._create_synthetic_id_table(key, first_sidecar)
1085
+
878
1086
  for sidecar in sidecars_to_load:
879
1087
  path = self.projection.tile_destination
880
1088
  if sidecar == "":
881
1089
  path = path / Path(key).with_suffix(".feather")
882
1090
  else:
883
1091
  path = path / Path(key).with_suffix(f".{sidecar}.feather")
884
- carfile = feather.read_table(path, memory_map=True)
885
- for col in carfile.column_names:
886
- if col in self.fields:
887
- tb = tb.append_column(col, carfile[col])
1092
+
1093
+ try:
1094
+ carfile = feather.read_table(path, memory_map=True)
1095
+ for col in carfile.column_names:
1096
+ if col in self.fields:
1097
+ tb = tb.append_column(col, carfile[col])
1098
+ except (FileNotFoundError, pa.ArrowInvalid) as e:
1099
+ logger.warning(f"Error loading data for key {key}, sidecar {sidecar}: {e}")
1100
+ continue
1101
+
888
1102
  tbs.append(tb)
889
1103
 
1104
+ if not tbs:
1105
+ raise ValueError("No data could be loaded. Data files may be missing or corrupt.")
1106
+
890
1107
  self._tb = pa.concat_tables(tbs)
891
1108
 
1109
+ def _create_synthetic_id_table(self, key, sidecar):
1110
+ """
1111
+ Create a synthetic table with position indices when datum_id file isn't available
1112
+ or when unique_id_field isn't specified.
1113
+ """
1114
+ # Try to determine the size of the table by loading a sidecar file
1115
+ path = self.projection.tile_destination
1116
+ if sidecar == "":
1117
+ path = path / Path(key).with_suffix(".feather")
1118
+ else:
1119
+ path = path / Path(key).with_suffix(f".{sidecar}.feather")
1120
+
1121
+ try:
1122
+ sidecar_tb = feather.read_table(path, memory_map=True)
1123
+ size = len(sidecar_tb)
1124
+ # Create a table with position indices as IDs
1125
+ position_indices = [f"pos_{i}" for i in range(size)]
1126
+ return pa.Table.from_arrays([pa.array(position_indices)], names=["_position_index"])
1127
+ except Exception as e:
1128
+ logger.error(f"Failed to create synthetic IDs for {key}: {e}")
1129
+ # Return an empty table as fallback
1130
+ return pa.Table.from_arrays([pa.array([])], names=["_position_index"])
1131
+
892
1132
  def _download_data(self, fields: Optional[List[str]] = None) -> List[Tuple[str, str]]:
893
1133
  """
894
1134
  Downloads the feather tree for user uploaded data.
@@ -902,16 +1142,26 @@ class AtlasMapData:
902
1142
  logger.info("Downloading data")
903
1143
  self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
904
1144
 
905
- # Download specified or all sidecar fields + always download datum_id
1145
+ # Download specified or all sidecar fields + always download datum_id if available
906
1146
  data_columns_to_load = [
907
1147
  (str(field), str(sidecar))
908
1148
  for field, sidecar in self.projection._registered_columns
909
1149
  if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
910
1150
  ]
911
1151
 
912
- # TODO: less confusing progress bar
913
- for sidecar in set([sidecar for _, sidecar in data_columns_to_load]):
1152
+ # Only download datum_id if we have a unique ID field
1153
+ if self._has_unique_id_field:
1154
+ try:
1155
+ self.projection._download_sidecar("datum_id")
1156
+ except ValueError as e:
1157
+ logger.warning(f"Failed to download datum_id files: {e}. Will use synthetic IDs instead.")
1158
+ self._has_unique_id_field = False
1159
+
1160
+ # Download all required sidecars for the fields
1161
+ sidecars_to_download = set(sidecar for _, sidecar in data_columns_to_load if sidecar != "datum_id")
1162
+ for sidecar in sidecars_to_download:
914
1163
  self.projection._download_sidecar(sidecar)
1164
+
915
1165
  return data_columns_to_load
916
1166
 
917
1167
  @property
@@ -115,15 +115,12 @@ class AtlasClass(object):
115
115
 
116
116
  return response.json()
117
117
 
118
- def _validate_map_data_inputs(self, colorable_fields, id_field, data_sample):
118
+ def _validate_map_data_inputs(self, colorable_fields, data_sample):
119
119
  """Validates inputs to map data calls."""
120
120
 
121
121
  if not isinstance(colorable_fields, list):
122
122
  raise ValueError("colorable_fields must be a list of fields")
123
123
 
124
- if id_field in colorable_fields:
125
- raise Exception(f"Cannot color by unique id field: {id_field}")
126
-
127
124
  for field in colorable_fields:
128
125
  if field not in data_sample:
129
126
  raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.")
@@ -274,14 +271,12 @@ class AtlasClass(object):
274
271
  """
275
272
  Private method. validates upload data against the dataset arrow schema, and associated other checks.
276
273
 
277
- 1. If unique_id_field is specified, validates that each datum has that field. If not, adds it and then notifies the user that it was added.
278
-
279
274
  Args:
280
275
  data: an arrow table.
281
276
  project: the atlas dataset you are validating the data for.
282
277
 
283
278
  Returns:
284
-
279
+ Validated pyarrow table.
285
280
  """
286
281
  if not isinstance(data, pa.Table):
287
282
  raise Exception("Invalid data type for upload: {}".format(type(data)))
@@ -295,8 +290,32 @@ class AtlasClass(object):
295
290
  msg = "Must include embeddings in embedding dataset upload."
296
291
  raise ValueError(msg)
297
292
 
298
- if project.id_field not in data.column_names:
299
- raise ValueError(f"Data must contain the ID column `{project.id_field}`")
293
+ # Check and validate ID field if specified
294
+ if "unique_id_field" in project.meta and project.meta["unique_id_field"] is not None:
295
+ id_field = project.meta["unique_id_field"]
296
+
297
+ # Check if ID field exists in data
298
+ if id_field not in data.column_names:
299
+ raise ValueError(
300
+ f"Data must contain the ID column `{id_field}` as specified in dataset's unique_id_field"
301
+ )
302
+
303
+ # Check for null values in ID field
304
+ if data[id_field].null_count > 0:
305
+ raise ValueError(
306
+ f"As your unique id field, {id_field} must not contain null values, but {data[id_field].null_count} found."
307
+ )
308
+
309
+ # Check ID field length (36 characters max)
310
+ if pa.types.is_string(data[id_field].type):
311
+ # Use a safer alternative to check string length
312
+ utf8_length_values = pc.utf8_length(data[id_field]) # type: ignore
313
+ max_length_scalar = pc.max(utf8_length_values) # type: ignore
314
+ max_length = max_length_scalar.as_py()
315
+ if max_length > 36:
316
+ raise ValueError(
317
+ f"The id_field contains values greater than 36 characters. Atlas does not support id_fields longer than 36 characters."
318
+ )
300
319
 
301
320
  seen = set()
302
321
  for col in data.column_names:
@@ -313,11 +332,6 @@ class AtlasClass(object):
313
332
  # filling in nulls, etc.
314
333
  reformatted = {}
315
334
 
316
- if data[project.id_field].null_count > 0:
317
- raise ValueError(
318
- f"{project.id_field} must not contain null values, but {data[project.id_field].null_count} found."
319
- )
320
-
321
335
  assert project.schema is not None, "Project schema not found."
322
336
 
323
337
  for field in project.schema:
@@ -335,10 +349,15 @@ class AtlasClass(object):
335
349
  f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version."
336
350
  )
337
351
  reformatted[field.name] = pc.fill_null(reformatted[field.name], "null")
338
- if pa.compute.any(pa.compute.equal(pa.compute.binary_length(reformatted[field.name]), 0)): # type: ignore
339
- mask = pa.compute.equal(pa.compute.binary_length(reformatted[field.name]), 0).combine_chunks() # type: ignore
340
- assert pa.types.is_boolean(mask.type) # type: ignore
341
- reformatted[field.name] = pa.compute.replace_with_mask(reformatted[field.name], mask, "null") # type: ignore
352
+
353
+ # Check for empty strings and replace with "null"
354
+ # Separate the operations for better type checking
355
+ binary_length_values = pc.binary_length(reformatted[field.name]) # type: ignore
356
+ has_empty_strings = pc.equal(binary_length_values, 0) # type: ignore
357
+ if pc.any(has_empty_strings).as_py(): # type: ignore
358
+ mask = has_empty_strings.combine_chunks()
359
+ assert pa.types.is_boolean(mask.type)
360
+ reformatted[field.name] = pc.replace_with_mask(reformatted[field.name], mask, "null") # type: ignore
342
361
  for field in data.schema:
343
362
  if not field.name in reformatted:
344
363
  if field.name == "_embeddings":
@@ -350,27 +369,12 @@ class AtlasClass(object):
350
369
  if project.meta["insert_update_delete_lock"]:
351
370
  raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.")
352
371
 
353
- # The following two conditions should never occur given the above, but just in case...
354
- assert project.id_field in data.column_names, f"Upload does not contain your specified id_field"
355
-
356
- if not pa.types.is_string(data[project.id_field].type):
357
- logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}")
358
- data = data.drop([project.id_field]).append_column(
359
- project.id_field, data[project.id_field].cast(pa.string())
360
- )
361
-
362
372
  for key in data.column_names:
363
373
  if key.startswith("_"):
364
374
  if key == "_embeddings" or key == "_blob_hash":
365
375
  continue
366
376
  raise ValueError("Metadata fields cannot start with _")
367
- if pa.compute.max(pa.compute.utf8_length(data[project.id_field])).as_py() > 36: # type: ignore
368
- first_match = data.filter(
369
- pa.compute.greater(pa.compute.utf8_length(data[project.id_field]), 36) # type: ignore
370
- ).to_pylist()[0][project.id_field]
371
- raise ValueError(
372
- f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters."
373
- )
377
+
374
378
  return data
375
379
 
376
380
  def _get_organization(self, organization_slug=None, organization_id=None) -> Tuple[str, str]:
@@ -696,36 +700,6 @@ class AtlasProjection:
696
700
  def tile_destination(self):
697
701
  return Path("~/.nomic/cache", self.id).expanduser()
698
702
 
699
- @property
700
- def datum_id_field(self):
701
- return self.dataset.meta["unique_id_field"]
702
-
703
- def _get_atoms(self, ids: List[str]) -> List[Dict]:
704
- """
705
- Retrieves atoms by id
706
-
707
- Args:
708
- ids: list of atom ids
709
-
710
- Returns:
711
- A dictionary containing the resulting atoms, keyed by atom id.
712
-
713
- """
714
-
715
- if not isinstance(ids, list):
716
- raise ValueError("You must specify a list of ids when getting data.")
717
-
718
- response = requests.post(
719
- self.dataset.atlas_api_path + "/v1/project/atoms/get",
720
- headers=self.dataset.header,
721
- json={"project_id": self.dataset.id, "index_id": self.atlas_index_id, "atom_ids": ids},
722
- )
723
-
724
- if response.status_code == 200:
725
- return response.json()["atoms"]
726
- else:
727
- raise Exception(response.text)
728
-
729
703
 
730
704
  class AtlasDataStream(AtlasClass):
731
705
  def __init__(self, name: Optional[str] = "contrastors"):
@@ -766,7 +740,7 @@ class AtlasDataset(AtlasClass):
766
740
 
767
741
  * **identifier** - The dataset identifier in the form `dataset` or `organization/dataset`. If no organization is passed, the organization tied to the API key you logged in to Nomic with will be used.
768
742
  * **description** - A description for the dataset.
769
- * **unique_id_field** - The field that uniquely identifies each data point.
743
+ * **unique_id_field** - A field that uniquely identifies each data point.
770
744
  * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
771
745
  * **dataset_id** - An alternative way to load a dataset is by passing the dataset_id directly. This only works if a dataset exists.
772
746
  """
@@ -805,13 +779,6 @@ class AtlasDataset(AtlasClass):
805
779
  dataset_id = dataset["id"]
806
780
 
807
781
  if dataset_id is None: # if there is no existing project, make a new one.
808
- if unique_id_field is None: # if not all parameters are specified, we weren't trying to make a project
809
- raise ValueError(f"Dataset `{identifier}` does not exist.")
810
-
811
- # if modality is None:
812
- # raise ValueError("You must specify a modality when creating a new dataset.")
813
- #
814
- # assert modality in ['text', 'embedding'], "Modality must be either `text` or `embedding`"
815
782
  assert identifier is not None
816
783
 
817
784
  dataset_id = self._create_project(
@@ -840,7 +807,7 @@ class AtlasDataset(AtlasClass):
840
807
  self,
841
808
  identifier: str,
842
809
  description: Optional[str],
843
- unique_id_field: str,
810
+ unique_id_field: Optional[str] = None,
844
811
  is_public: bool = True,
845
812
  ):
846
813
  """
@@ -852,7 +819,7 @@ class AtlasDataset(AtlasClass):
852
819
 
853
820
  * **identifier** - The identifier for the dataset.
854
821
  * **description** - A description for the dataset.
855
- * **unique_id_field** - The field that uniquely identifies each datum. If a datum does not contain this field, it will be added and assigned a random unique ID.
822
+ * **unique_id_field** - A field that uniquely identifies each data point.
856
823
  * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
857
824
 
858
825
  **Returns:** project_id on success.
@@ -865,15 +832,6 @@ class AtlasDataset(AtlasClass):
865
832
  if "/" in identifier:
866
833
  org_name = identifier.split("/")[0]
867
834
  logger.info(f"Organization name: `{org_name}`")
868
- # supported_modalities = ['text', 'embedding']
869
- # if modality not in supported_modalities:
870
- # msg = 'Tried to create dataset with modality: {}, but Atlas only supports: {}'.format(
871
- # modality, supported_modalities
872
- # )
873
- # raise ValueError(msg)
874
-
875
- if unique_id_field is None:
876
- raise ValueError("You must specify a unique id field")
877
835
  if description is None:
878
836
  description = ""
879
837
  response = requests.post(
@@ -884,7 +842,6 @@ class AtlasDataset(AtlasClass):
884
842
  "project_name": project_slug,
885
843
  "description": description,
886
844
  "unique_id_field": unique_id_field,
887
- # 'modality': modality,
888
845
  "is_public": is_public,
889
846
  },
890
847
  )
@@ -939,11 +896,12 @@ class AtlasDataset(AtlasClass):
939
896
 
940
897
  @property
941
898
  def id(self) -> str:
942
- """The UUID of the dataset."""
899
+ """The ID of the dataset."""
943
900
  return self.meta["id"]
944
901
 
945
902
  @property
946
903
  def id_field(self) -> str:
904
+ """The unique_id_field of the dataset."""
947
905
  return self.meta["unique_id_field"]
948
906
 
949
907
  @property
@@ -1147,7 +1105,7 @@ class AtlasDataset(AtlasClass):
1147
1105
  colorable_fields = []
1148
1106
 
1149
1107
  for field in self.dataset_fields:
1150
- if field not in [self.id_field, indexed_field] and not field.startswith("_"):
1108
+ if field not in [indexed_field] and not field.startswith("_"):
1151
1109
  colorable_fields.append(field)
1152
1110
 
1153
1111
  build_template = {}
@@ -1410,98 +1368,126 @@ class AtlasDataset(AtlasClass):
1410
1368
  Uploads blobs to the server and associates them with the data.
1411
1369
  Blobs must reference objects stored locally
1412
1370
  """
1371
+ data_as_table: pa.Table
1413
1372
  if isinstance(data, DataFrame):
1414
- data = pa.Table.from_pandas(data)
1373
+ data_as_table = pa.Table.from_pandas(data)
1415
1374
  elif isinstance(data, list):
1416
- data = pa.Table.from_pylist(data)
1417
- elif not isinstance(data, pa.Table):
1375
+ data_as_table = pa.Table.from_pylist(data)
1376
+ elif isinstance(data, pa.Table):
1377
+ data_as_table = data
1378
+ else:
1418
1379
  raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.")
1419
1380
 
1420
- blob_upload_endpoint = "/v1/project/data/add/blobs"
1421
-
1422
- # uploda batch of blobs
1423
- # return hash of blob
1424
- # add hash to data as _blob_hash
1425
- # set indexed_field to _blob_hash
1426
- # call _add_data
1427
-
1428
- # Cast self id field to string for merged data lower down on function
1429
- data = data.set_column( # type: ignore
1430
- data.schema.get_field_index(self.id_field), self.id_field, pc.cast(data[self.id_field], pa.string()) # type: ignore
1431
- )
1381
+ # Compute dataset length
1382
+ data_length = len(data_as_table)
1383
+ if data_length != len(blobs):
1384
+ raise ValueError(f"Number of data points ({data_length}) must match number of blobs ({len(blobs)})")
1432
1385
 
1433
- ids = data[self.id_field].to_pylist() # type: ignore
1434
- if not isinstance(ids[0], str):
1435
- ids = [str(uuid) for uuid in ids]
1436
-
1437
- # TODO: add support for other modalities
1438
- images = []
1439
- for uuid, blob in tqdm(zip(ids, blobs), total=len(ids), desc="Loading images"):
1440
- if (isinstance(blob, str) or isinstance(blob, Path)) and os.path.exists(blob):
1441
- # Auto resize to max 512x512
1442
- image = Image.open(blob)
1443
- image = image.convert("RGB")
1386
+ TEMP_ID_COLUMN = "_nomic_internal_temp_id"
1387
+ temp_id_values = [str(i) for i in range(data_length)]
1388
+ data_as_table = data_as_table.append_column(TEMP_ID_COLUMN, pa.array(temp_id_values, type=pa.string()))
1389
+ blob_upload_endpoint = "/v1/project/data/add/blobs"
1390
+ actual_temp_ids = data_as_table[TEMP_ID_COLUMN].to_pylist()
1391
+ images = [] # List of (temp_id, image_bytes)
1392
+ for i in tqdm(range(data_length), desc="Processing images"):
1393
+ current_temp_id = actual_temp_ids[i]
1394
+ blob_item = blobs[i]
1395
+
1396
+ processed_blob_value = None
1397
+ if (isinstance(blob_item, str) or isinstance(blob_item, Path)) and os.path.exists(blob_item):
1398
+ image = Image.open(blob_item).convert("RGB")
1444
1399
  if image.height > 512 or image.width > 512:
1445
1400
  image = image.resize((512, 512))
1446
1401
  buffered = BytesIO()
1447
1402
  image.save(buffered, format="JPEG")
1448
- images.append((uuid, buffered.getvalue()))
1449
- elif isinstance(blob, bytes):
1450
- images.append((uuid, blob))
1451
- elif isinstance(blob, Image.Image):
1452
- blob = blob.convert("RGB") # type: ignore
1453
- if blob.height > 512 or blob.width > 512:
1454
- blob = blob.resize((512, 512))
1403
+ processed_blob_value = buffered.getvalue()
1404
+ elif isinstance(blob_item, bytes):
1405
+ processed_blob_value = blob_item
1406
+ elif isinstance(blob_item, Image.Image):
1407
+ img_pil = blob_item.convert("RGB") # Ensure it's PIL Image for methods
1408
+ if img_pil.height > 512 or img_pil.width > 512:
1409
+ img_pil = img_pil.resize((512, 512))
1455
1410
  buffered = BytesIO()
1456
- blob.save(buffered, format="JPEG")
1457
- images.append((uuid, buffered.getvalue()))
1411
+ img_pil.save(buffered, format="JPEG")
1412
+ processed_blob_value = buffered.getvalue()
1458
1413
  else:
1459
- raise ValueError(f"Invalid blob type for {uuid}. Must be a path to an image, bytes, or PIL Image.")
1414
+ raise ValueError(
1415
+ f"Invalid blob type for item at index {i} (temp_id: {current_temp_id}). Must be a path, bytes, or PIL Image. Got: {type(blob_item)}"
1416
+ )
1417
+
1418
+ if processed_blob_value is not None:
1419
+ images.append((current_temp_id, processed_blob_value))
1460
1420
 
1461
1421
  batch_size = 40
1462
- num_workers = 10
1422
+ num_workers = 2
1423
+
1424
+ def send_request(batch_start_index):
1425
+ image_batch = images[batch_start_index : batch_start_index + batch_size]
1426
+ temp_ids_in_batch = [item_id for item_id, _ in image_batch]
1427
+ blobs_for_api = [("blobs", blob_val) for _, blob_val in image_batch]
1463
1428
 
1464
- def send_request(i):
1465
- image_batch = images[i : i + batch_size]
1466
- ids = [uuid for uuid, _ in image_batch]
1467
- blobs = [("blobs", blob) for _, blob in image_batch]
1468
1429
  response = requests.post(
1469
1430
  self.atlas_api_path + blob_upload_endpoint,
1470
1431
  headers=self.header,
1471
- data={"dataset_id": self.id},
1472
- files=blobs,
1432
+ data={"dataset_id": self.id}, # self.id is project_id
1433
+ files=blobs_for_api,
1473
1434
  )
1474
1435
  if response.status_code != 200:
1475
- raise Exception(response.text)
1476
- return {uuid: blob_hash for uuid, blob_hash in zip(ids, response.json()["hashes"])}
1436
+ failed_ids_sample = temp_ids_in_batch[:5]
1437
+ logger.error(
1438
+ f"Blob upload request failed for batch starting with temp_ids: {failed_ids_sample}. Status: {response.status_code}, Response: {response.text}"
1439
+ )
1440
+ raise Exception(f"Blob upload failed: {response.text}")
1441
+ return {temp_id: blob_hash for temp_id, blob_hash in zip(temp_ids_in_batch, response.json()["hashes"])}
1477
1442
 
1478
- # if this method is being called internally, we pass a global progress bar
1479
- if pbar is None:
1480
- pbar = tqdm(total=len(data), desc="Uploading blobs to Atlas")
1443
+ upload_pbar = pbar # Use passed-in pbar if available
1444
+ close_upload_pbar_locally = False
1445
+ if upload_pbar is None:
1446
+ upload_pbar = tqdm(total=len(images), desc="Uploading blobs to Atlas")
1447
+ close_upload_pbar_locally = True
1481
1448
 
1482
- hash_schema = pa.schema([(self.id_field, pa.string()), ("_blob_hash", pa.string())])
1483
- returned_ids = []
1449
+ returned_temp_ids = []
1484
1450
  returned_hashes = []
1451
+ succeeded_uploads = 0
1485
1452
 
1486
- succeeded = 0
1487
1453
  with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
1488
- futures = {executor.submit(send_request, i): i for i in range(0, len(data), batch_size)}
1454
+ futures = {executor.submit(send_request, i): i for i in range(0, len(images), batch_size)}
1489
1455
 
1490
1456
  for future in concurrent.futures.as_completed(futures):
1491
- response = future.result()
1492
- # add hash to data as _blob_hash
1493
- for uuid, blob_hash in response.items():
1494
- returned_ids.append(uuid)
1495
- returned_hashes.append(blob_hash)
1496
-
1497
- # A successful upload.
1498
- succeeded += len(response)
1499
- pbar.update(len(response))
1457
+ try:
1458
+ response_dict = future.result() # This is {temp_id: blob_hash}
1459
+ for temp_id, blob_hash_val in response_dict.items():
1460
+ returned_temp_ids.append(temp_id)
1461
+ returned_hashes.append(blob_hash_val)
1462
+ succeeded_uploads += len(response_dict)
1463
+ if upload_pbar:
1464
+ upload_pbar.update(len(response_dict))
1465
+ except Exception as e:
1466
+ logger.error(f"An error occurred during blob upload processing for a batch: {e}")
1467
+ # Optionally, collect failed batch info here if needed for partial success
1468
+
1469
+ if close_upload_pbar_locally and upload_pbar:
1470
+ upload_pbar.close()
1471
+
1472
+ hash_schema = pa.schema([(TEMP_ID_COLUMN, pa.string()), ("_blob_hash", pa.string())])
1473
+
1474
+ merged_data_as_table: pa.Table
1475
+ if succeeded_uploads > 0: # Only create hash_tb if there are successful uploads
1476
+ hash_tb = pa.Table.from_pydict(
1477
+ {TEMP_ID_COLUMN: returned_temp_ids, "_blob_hash": returned_hashes}, schema=hash_schema
1478
+ )
1479
+ merged_data_as_table = data_as_table.join(right_table=hash_tb, keys=TEMP_ID_COLUMN, join_type="left outer")
1480
+ else: # No successful uploads, so no hashes to merge, but keep original data structure
1481
+ # Need to ensure _blob_hash column is added with nulls, and id_field is present
1482
+ if "_blob_hash" not in data_as_table.column_names:
1483
+ data_as_table = data_as_table.append_column(
1484
+ "_blob_hash", pa.nulls(data_as_table.num_rows, type=pa.string())
1485
+ )
1486
+ merged_data_as_table = data_as_table
1500
1487
 
1501
- hash_tb = pa.Table.from_pydict({self.id_field: returned_ids, "_blob_hash": returned_hashes}, schema=hash_schema)
1502
- merged_data = data.join(right_table=hash_tb, keys=self.id_field) # type: ignore
1488
+ merged_data_as_table = merged_data_as_table.drop_columns([TEMP_ID_COLUMN])
1503
1489
 
1504
- self._add_data(merged_data, pbar=pbar)
1490
+ self._add_data(merged_data_as_table, pbar=pbar) # Pass original pbar argument
1505
1491
 
1506
1492
  def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None):
1507
1493
  """
@@ -1580,12 +1566,8 @@ class AtlasDataset(AtlasClass):
1580
1566
  None
1581
1567
  """
1582
1568
 
1583
- # Exactly 10 upload workers at a time.
1584
-
1585
- num_workers = 10
1586
-
1569
+ num_workers = 2
1587
1570
  # Each worker currently is too slow beyond a shard_size of 10000
1588
-
1589
1571
  # The heuristic here is: Never let shards be more than 10,000 items,
1590
1572
  # OR more than 16MB uncompressed. Whichever is smaller.
1591
1573
 
@@ -1701,7 +1683,7 @@ class AtlasDataset(AtlasClass):
1701
1683
  else:
1702
1684
  logger.info("Upload succeeded.")
1703
1685
 
1704
- def update_maps(self, data: List[Dict], embeddings: Optional[np.ndarray] = None, num_workers: int = 10):
1686
+ def update_maps(self, data: List[Dict], embeddings: Optional[np.ndarray] = None, num_workers: int = 2):
1705
1687
  """
1706
1688
  Utility method to update a project's maps by adding the given data.
1707
1689
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nomic
3
- Version: 3.5.0
3
+ Version: 3.5.2
4
4
  Summary: The official Nomic python client.
5
5
  Home-page: https://github.com/nomic-ai/nomic
6
6
  Author: nomic.ai
@@ -23,7 +23,7 @@ with open("README.md") as f:
23
23
 
24
24
  setup(
25
25
  name="nomic",
26
- version="3.5.0",
26
+ version="3.5.2",
27
27
  url="https://github.com/nomic-ai/nomic",
28
28
  description=description,
29
29
  long_description=long_description,
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes