nomic 3.0.32.tar.gz → 3.0.34.tar.gz

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release. This version of nomic might be problematic.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nomic
-Version: 3.0.32
+Version: 3.0.34
 Summary: The official Nomic python client.
 Home-page: https://github.com/nomic-ai/nomic
 Author: nomic.ai
@@ -88,7 +88,7 @@ class NomicTopicOptions(BaseModel):
     """
 
     build_topic_model: bool = True
-    topic_label_field: Optional[str] = Field(default=None, alias="community_description_target_field")
+    topic_label_field: Optional[str] = Field(default=None)
     cluster_method: str = "fast"
     enforce_topic_hierarchy: bool = False
 
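Note: the only functional change in this hunk is dropping the `community_description_target_field` alias. By default, pydantic populates an aliased field via its alias rather than its field name, so removing the alias changes which keyword the options model accepts. A standalone sketch of that behavior (not nomic's actual model):

    from typing import Optional
    from pydantic import BaseModel, Field

    class WithAlias(BaseModel):
        # With an alias, the alias is the accepted keyword by default
        topic_label_field: Optional[str] = Field(default=None, alias="community_description_target_field")

    class WithoutAlias(BaseModel):
        topic_label_field: Optional[str] = Field(default=None)

    WithAlias(community_description_target_field="title")  # populated via the alias
    WithoutAlias(topic_label_field="title")                # populated via the field name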
 
@@ -1,11 +1,10 @@
 import base64
 import io
 import json
-import os
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -16,9 +15,6 @@ from pyarrow import compute as pc
 from pyarrow import feather
 from tqdm import tqdm
 
-from .settings import EMBEDDING_PAGINATION_LIMIT
-from .utils import download_feather
-
 
 class AtlasMapDuplicates:
     """
@@ -30,22 +26,57 @@ class AtlasMapDuplicates:
     def __init__(self, projection: "AtlasProjection"):  # type: ignore
         self.projection = projection
         self.id_field = self.projection.dataset.id_field
-        try:
-            duplicate_fields = [
-                field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field
-            ]
-            cluster_fields = [field for field in projection._fetch_tiles().column_names if "_cluster" in field]
-            assert len(duplicate_fields) > 0, "Duplicate detection has not yet been run on this map."
-            self.duplicate_field = duplicate_fields[0]
-            self.cluster_field = cluster_fields[0]
-            self._tb: pa.Table = projection._fetch_tiles().select(
-                [self.id_field, self.duplicate_field, self.cluster_field]
+
+        duplicate_columns = [
+            (field, sidecar)
+            for field, sidecar in self.projection._registered_columns
+            if field.startswith("_duplicate_class")
+        ]
+        cluster_columns = [
+            (field, sidecar) for field, sidecar in self.projection._registered_columns if field.startswith("_cluster")
+        ]
+
+        assert len(duplicate_columns) > 0, "Duplicate detection has not yet been run on this map."
+
+        self._duplicate_column = duplicate_columns[0]
+        self._cluster_column = cluster_columns[0]
+        self._tb = None
+
+    def _load_duplicates(self):
+        """
+        Loads duplicates from the feather tree.
+        """
+        tbs = []
+        duplicate_sidecar = self._duplicate_column[1]
+        self.duplicate_field = self._duplicate_column[0].lstrip("_")
+        self.cluster_field = self._cluster_column[0].lstrip("_")
+        logger.info("Loading duplicates")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
             )
-        except pa.lib.ArrowInvalid as e:  # type: ignore
-            raise ValueError("Duplicate detection has not yet been run on this map.")
-        self.duplicate_field = self.duplicate_field.lstrip("_")
-        self.cluster_field = self.cluster_field.lstrip("_")
-        self._tb = self._tb.rename_columns([self.id_field, self.duplicate_field, self.cluster_field])
+            path = self.projection.tile_destination
+
+            if duplicate_sidecar == "":
+                path = path / Path(key).with_suffix(".feather")
+            else:
+                path = path / Path(key).with_suffix(f".{duplicate_sidecar}.feather")
+
+            duplicate_tb = feather.read_table(path, memory_map=True)
+            for field in (self._duplicate_column[0], self._cluster_column[0]):
+                tb = tb.append_column(field, duplicate_tb[field])
+            tbs.append(tb)
+        self._tb = pa.concat_tables(tbs).rename_columns([self.id_field, self.duplicate_field, self.cluster_field])
+
+    def _download_duplicates(self):
+        """
+        Downloads the feather tree for duplicates.
+        """
+        logger.info("Downloading duplicates")
+        self.projection._download_sidecar("datum_id", overwrite=False)
+        assert self._cluster_column[1] == self._duplicate_column[1], "Cluster and duplicate should be in same sidecar"
+        self.projection._download_sidecar(self._duplicate_column[1], overwrite=False)
 
     @property
     def df(self) -> pd.DataFrame:
@@ -61,6 +92,10 @@ class AtlasMapDuplicates:
         This table is memmapped from the underlying files and is the most efficient way to
         access duplicate information.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+        self._download_duplicates()
+        self._load_duplicates()
         return self._tb
 
     def deletion_candidates(self) -> List[str]:
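This lazy `tb` pattern recurs throughout the release: constructors stop doing network I/O, and the first property access downloads the feather tiles and then memory-maps and concatenates them. A minimal runnable sketch of the pattern (class and data invented):

    import pyarrow as pa

    class LazyTable:
        """Sketch of the deferred download/load pattern in this release (hypothetical class)."""

        def __init__(self):
            self._tb = None  # no network or disk I/O at construction time

        def _download(self):
            # stand-in for _download_duplicates/_download_topics: fetch feather files to a cache
            self._paths = ["tile0", "tile1"]

        def _load(self):
            # stand-in for _load_duplicates/_load_topics: read each cached file and concatenate
            tiles = [pa.table({"id": [i], "value": [i * i]}) for i, _ in enumerate(self._paths)]
            self._tb = pa.concat_tables(tiles)

        @property
        def tb(self) -> pa.Table:
            if isinstance(self._tb, pa.Table):  # already materialized, return cached table
                return self._tb
            self._download()
            self._load()
            return self._tb

    print(LazyTable().tb.num_rows)  # 2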
@@ -93,35 +128,69 @@ class AtlasMapTopics:
         self.id_field = self.projection.dataset.id_field
         self._metadata = None
         self._hierarchy = None
+        self._topic_columns = [
+            column for column in self.projection._registered_columns if column[0].startswith("_topic_depth_")
+        ]
+        assert len(self._topic_columns) > 0, "Topic modeling has not yet been run on this map."
+        self.depth = len(self._topic_columns)
+        self._tb = None
 
-        try:
-            logger.info("Downloading topics")
-            self._tb: pa.Table = projection._fetch_tiles()
-            topic_fields = [column for column in self._tb.column_names if column.startswith("_topic_depth_")]
-            self.depth = len(topic_fields)
-
-            # If using topic ids, fetch topic labels
-            if "int" in topic_fields[0]:
-                new_topic_fields = []
-                label_df = self.metadata[["topic_id", "depth", "topic_short_description"]]
-                for d in range(1, self.depth + 1):
+    def _load_topics(self):
+        """
+        Loads topics from the feather tree.
+        """
+        integer_topics = False
+        # pd.Series to match pd typing
+        label_df: Optional[Union[pd.DataFrame, pd.Series]] = None
+        if "int" in self._topic_columns[0][0]:
+            integer_topics = True
+            label_df = self.metadata[["topic_id", "depth", "topic_short_description"]]
+        tbs = []
+        # Should just be one sidecar
+        topic_sidecar = set([sidecar for _, sidecar in self._topic_columns]).pop()
+        logger.info("Loading topics")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+            )
+            path = self.projection.tile_destination
+            if topic_sidecar == "":
+                path = path / Path(key).with_suffix(".feather")
+            else:
+                path = path / Path(key).with_suffix(f".{topic_sidecar}.feather")
+
+            topic_tb = feather.read_table(path, memory_map=True)
+            # Do this in depth order
+            for d in range(1, self.depth + 1):
+                column = f"_topic_depth_{d}"
+                if integer_topics:
                     column = f"_topic_depth_{d}_int"
-                    topic_ids_to_label = self._tb[column].to_pandas().rename("topic_id")
+                    topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
+                    assert label_df is not None
                     topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
                         topic_ids_to_label, on="topic_id", how="right"
                     )
                     new_column = f"_topic_depth_{d}"
-                    self._tb = self._tb.append_column(
+                    tb = tb.append_column(
                         new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
                     )
-                    new_topic_fields.append(new_column)
-                topic_fields = new_topic_fields
+                else:
+                    tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
+            tbs.append(tb)
 
-            renamed_fields = [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
-            self._tb = self._tb.select([self.id_field] + topic_fields).rename_columns([self.id_field] + renamed_fields)
+        renamed_columns = [self.id_field] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
+        self._tb = pa.concat_tables(tbs).rename_columns(renamed_columns)
 
-        except pa.lib.ArrowInvalid as e:  # type: ignore
-            raise ValueError("Topic modeling has not yet been run on this map.")
+    def _download_topics(self):
+        """
+        Downloads the feather tree for topics.
+        """
+        logger.info("Downloading topics")
+        self.projection._download_sidecar("datum_id", overwrite=False)
+        topic_sidecars = set([sidecar for _, sidecar in self._topic_columns])
+        assert len(topic_sidecars) == 1, "Multiple topic sidecars found."
+        self.projection._download_sidecar(topic_sidecars.pop(), overwrite=False)
 
     @property
     def df(self) -> pd.DataFrame:
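When tiles store integer topic ids, `_load_topics` recovers human-readable labels by right-merging the topic metadata onto the per-datum ids; `how="right"` keeps one row per datum, in tile order. A toy illustration with invented data:

    import pandas as pd

    label_df = pd.DataFrame({
        "topic_id": [0, 1],
        "depth": [1, 1],
        "topic_short_description": ["sports", "politics"],
    })
    # Integer topic ids for three datums, renamed to match the metadata key
    topic_ids_to_label = pd.Series([1, 0, 1], name="topic_id")
    merged = label_df[label_df["depth"] == 1].merge(topic_ids_to_label, on="topic_id", how="right")
    print(merged["topic_short_description"].tolist())  # ['politics', 'sports', 'politics']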
@@ -137,6 +206,10 @@ class AtlasMapTopics:
         This table is memmapped from the underlying files and is the most efficient way to
         access topic information.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+        self._download_topics()
+        self._load_topics()
         return self._tb
 
     @property
@@ -260,8 +333,8 @@ class AtlasMapTopics:
             A list of `{topic, count}` dictionaries, sorted from largest count to smallest count.
         """
         data = AtlasMapData(self.projection, fields=[time_field])
-        time_data = data._tb.select([self.id_field, time_field])
-        merged_tb = self._tb.join(time_data, self.id_field, join_type="inner").combine_chunks()
+        time_data = data.tb.select([self.id_field, time_field])
+        merged_tb = self.tb.join(time_data, self.id_field, join_type="inner").combine_chunks()
 
         del time_data  # free up memory
 
@@ -376,8 +449,8 @@ class AtlasMapEmbeddings:
     def __init__(self, projection: "AtlasProjection"):  # type: ignore
         self.projection = projection
         self.id_field = self.projection.dataset.id_field
-        self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, "x", "y"])
         self.dataset = projection.dataset
+        self._tb: pa.Table = None
         self._latent = None
 
     @property
@@ -398,6 +471,31 @@ class AtlasMapEmbeddings:
 
         Does not include high-dimensional embeddings.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+
+        self._download_projected()
+
+        logger.info("Loading projected embeddings")
+
+        tbs = []
+        coord_sidecar = self.projection._get_sidecar_from_field("x")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+            )
+            path = self.projection.tile_destination
+            if coord_sidecar == "":
+                path = path / Path(key).with_suffix(".feather")
+            else:
+                path = path / Path(key).with_suffix(f".{coord_sidecar}.feather")
+            carfile = feather.read_table(path, memory_map=True)
+            for col in carfile.column_names:
+                if col in ["x", "y"]:
+                    tb = tb.append_column(col, carfile[col])
+            tbs.append(tb)
+        self._tb = pa.concat_tables(tbs)
         return self._tb
 
     @property
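First access to the projected coordinates now performs the download and load. A hypothetical access pattern (the dataset identifier is invented; `AtlasDataset` and the `maps`/`embeddings` accessors are assumed from the client's public API):

    from nomic import AtlasDataset

    projection = AtlasDataset("my-org/my-dataset").maps[0]  # hypothetical identifier
    coords = projection.embeddings.projected  # downloads sidecars, then memory-maps and concatenates tiles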
@@ -426,38 +524,40 @@ class AtlasMapEmbeddings:
         downloaded_files_in_tile_order = self._download_latent()
         assert len(downloaded_files_in_tile_order) > 0, "No embeddings found for this map."
         all_embeddings = []
-
-        for path in downloaded_files_in_tile_order:
-            # Should there be more than 10, we need to sort by int values, not string values
+        logger.info("Loading latent embeddings")
+        for path in tqdm(downloaded_files_in_tile_order):
             tb = feather.read_table(path, memory_map=True)
             dims = tb["_embeddings"].type.list_size
             all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
         return np.vstack(all_embeddings)
 
+    def _download_projected(self) -> List[Path]:
+        """
+        Downloads the feather tree for projection coordinates.
+        """
+        logger.info("Downloading projected embeddings")
+        # Note that y coord should be in same sidecar
+        coord_sidecar = self.projection._get_sidecar_from_field("x")
+        self.projection._download_sidecar("datum_id", overwrite=False)
+        return self.projection._download_sidecar(coord_sidecar, overwrite=False)
+
     def _download_latent(self) -> List[Path]:
         """
         Downloads the feather tree for embeddings.
         Returns the path to downloaded embeddings.
         """
         # TODO: Is size of the embedding files (several hundreds of MBs) going to be a problem here?
-        self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
-        root_url = Path(
-            f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
-        )
-
-        registered_sidecar_names = [sidecar[1] for sidecar in self.projection._registered_sidecars()]
-        assert "embeddings" in registered_sidecar_names, "Embeddings not found in sidecars."
-
-        downloaded_files_in_tile_order = []
-        logger.info("Downloading latent embeddings...")
-        all_quads = list(self.projection._tiles_in_order())
-        for quad in tqdm(all_quads):
-            path = quad.with_suffix(".embeddings.feather")
-            # WARNING: Potentially large data request here
-            quadtree_loc = Path(*path.parts[-3:])
-            download_feather(root_url / quadtree_loc, path, headers=self.dataset.header, overwrite=False)
-            downloaded_files_in_tile_order.append(path)
-        return downloaded_files_in_tile_order
+        logger.info("Downloading latent embeddings")
+        embedding_sidecar = None
+        for field, sidecar in self.projection._registered_columns:
+            # NOTE: be _embeddings or _embedding
+            if field == "_embeddings":
+                embedding_sidecar = sidecar
+                break
+
+        if embedding_sidecar is None:
+            raise ValueError("No embeddings found for this map.")
+        return self.projection._download_sidecar(embedding_sidecar, overwrite=False)
 
     def vector_search(
         self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5
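The `latent` loader depends on Arrow fixed-size list columns: `list_size` gives the embedding dimension, and flattening then reshaping recovers an (n, dims) matrix without copying row by row. A self-contained illustration:

    import pyarrow as pa
    import pyarrow.compute as pc

    # Two 3-dimensional embeddings stored as a fixed-size list column, as in the tile files
    emb = pa.FixedSizeListArray.from_arrays(pa.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]), 3)
    tb = pa.table({"_embeddings": emb})

    dims = tb["_embeddings"].type.list_size               # 3
    flat = pc.list_flatten(tb["_embeddings"]).to_numpy()  # 1-D array of length n * dims
    matrix = flat.reshape(-1, dims)                       # back to shape (2, 3)
    print(matrix.shape)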
@@ -571,12 +671,15 @@ class AtlasMapTags:
         self.projection = projection
         self.dataset = projection.dataset
         self.id_field = self.projection.dataset.id_field
-        # Pre-fetch tiles first upon initialization
-        self.projection._fetch_tiles(overwrite=False)
+        # Pre-fetch datum ids first upon initialization
+        try:
+            self.projection._download_sidecar("datum_id")
+        except Exception:
+            raise ValueError("Failed to fetch datum ids which is required to load tags.")
         self.auto_cleanup = auto_cleanup
 
     @property
-    def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame:
+    def df(self, overwrite: bool = False) -> pd.DataFrame:
         """
         Pandas DataFrame mapping each data point to its tags.
         """
@@ -587,16 +690,13 @@ class AtlasMapTags:
         for tag in tags:
             self._download_tag(tag["tag_name"], overwrite=overwrite)
         tbs = []
-        all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        for quad in tqdm(all_quads):
-            quad_str = os.path.join(*[str(q) for q in quad])
-            datum_id_filename = quad_str + "." + "datum_id" + ".feather"
-            path = self.projection.tile_destination / Path(datum_id_filename)
-            tb = feather.read_table(path, memory_map=True)
+        logger.info("Loading tags")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+            tb = feather.read_table(datum_id_path, memory_map=True)
             for tag in tags:
                 tag_definition_id = tag["tag_definition_id"]
-                tag_filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
-                path = self.projection.tile_destination / Path(tag_filename)
+                path = self.projection.tile_destination / Path(key).with_suffix(f"._tag.{tag_definition_id}.feather")
                 tag_tb = feather.read_table(path, memory_map=True)
                 bitmask = None
                 if "all_set" in tag_tb.column_names:
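The hunk cuts off at the `all_set` check, so the full decoding logic is not visible here. Judging from that check, a tag tile either flags every datum as tagged or carries a per-datum boolean mask; a hypothetical decoder under that assumption (the `bitmask` column name is a guess):

    import pyarrow as pa

    def tag_mask(tag_tb: pa.Table, num_rows: int):
        # Hypothetical: "all_set" marks the whole tile as tagged; otherwise
        # assume a per-row boolean column (the "bitmask" name is a guess).
        if "all_set" in tag_tb.column_names:
            return [True] * num_rows
        return tag_tb["bitmask"].to_pylist()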
@@ -635,7 +735,7 @@ class AtlasMapTags:
                 keep_tags.append(tag)
         return keep_tags
 
-    def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False):
+    def get_datums_in_tag(self, tag_name: str, overwrite: bool = False):
         """
         Returns the datum ids in a given tag.
 
@@ -646,9 +746,9 @@ class AtlasMapTags:
         Returns:
             List of datum ids.
         """
-        ordered_tag_paths = self._download_tag(tag_name, overwrite=overwrite)
+        tag_paths = self._download_tag(tag_name, overwrite=overwrite)
         datum_ids = []
-        for path in ordered_tag_paths:
+        for path in tag_paths:
             tb = feather.read_table(path)
             last_coord = path.name.split(".")[0]
             tile_path = path.with_name(last_coord + ".datum_id.feather")
@@ -675,26 +775,14 @@ class AtlasMapTags:
             return tag
         raise ValueError(f"Tag {name} not found in projection {self.projection.id}.")
 
-    def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False):
+    def _download_tag(self, tag_name: str, overwrite: bool = False):
         """
         Downloads the feather tree for large sidecar columns.
         """
         logger.info("Downloading tags")
-        self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
-        root_url = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
-
         tag = self._get_tag_by_name(tag_name)
         tag_definition_id = tag["tag_definition_id"]
-
-        all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        ordered_tag_paths = []
-        for quad in tqdm(all_quads):
-            quad_str = os.path.join(*[str(q) for q in quad])
-            filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
-            path = self.projection.tile_destination / Path(filename)
-            download_feather(root_url + filename, path, headers=self.dataset.header, overwrite=True)
-            ordered_tag_paths.append(path)
-        return ordered_tag_paths
+        return self.projection._download_sidecar(f"_tag.{tag_definition_id}", overwrite=overwrite)
 
     def _remove_outdated_tag_files(self, tag_definition_ids: List[str]):
         """
@@ -705,14 +793,12 @@ class AtlasMapTags:
             tag_definition_ids: A list of tag definition ids to keep.
         """
         # NOTE: This currently only gets triggered on `df` property
-        all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        for quad in tqdm(all_quads):
-            quad_str = os.path.join(*[str(q) for q in quad])
-            tile = self.projection.tile_destination / Path(quad_str)
+        for key in self.projection._manifest["key"].to_pylist():
+            tile = self.projection.tile_destination / Path(key)
             tile_dir = tile.parent
             if tile_dir.exists():
-                tagged_files = tile_dir.glob("*_tag*")
-                for file in tagged_files:
+                tag_files = tile_dir.glob("*_tag*")
+                for file in tag_files:
                     tag_definition_id = file.name.split(".")[-2]
                     if tag_definition_id in tag_definition_ids:
                         try:
@@ -764,61 +850,69 @@ class AtlasMapData:
         self.projection = projection
         self.dataset = projection.dataset
         self.id_field = self.projection.dataset.id_field
-        try:
-            # Run fetch_tiles first to guarantee existence of quad feather files
-            self._basic_data: pa.Table = self.projection._fetch_tiles()
-            sidecars = self._download_data(fields=fields)
-            self._tb = self._read_prefetched_tiles_with_sidecars(sidecars)
+        if fields is None:
+            # TODO: fall back on something more reliable here
+            self.fields = self.dataset.dataset_fields
+        else:
+            for field in fields:
+                assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
+            self.fields = fields
+        self._tb = None
 
-        except pa.lib.ArrowInvalid as e:  # type: ignore
-            raise ValueError("Failed to fetch tiles for this map")
+    def _load_data(self, data_columns: List[Tuple[str, str]]):
+        """
+        Loads data from a list of data columns (field and sidecar name tuples).
 
-    def _read_prefetched_tiles_with_sidecars(self, sidecars):
+        Args:
+            data_columns: A list of tuples containing field name and sidecar name.
+        """
         tbs = []
-        for path in self.projection._tiles_in_order():
-            tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"])  # type: ignore
-            for col in tb.column_names:
-                if col[0] == "_":
-                    tb = tb.drop([col])
-            for _, sidecar in sidecars:
-                carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar}.feather", memory_map=True)  # type: ignore
+
+        sidecars_to_load = set([sidecar for _, sidecar in data_columns if sidecar != "datum_id"])
+        logger.info("Loading data")
+        for key in tqdm(self.projection._manifest["key"].to_pylist()):
+            # Use datum id as root table
+            tb = feather.read_table(
+                self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+            )
+            for sidecar in sidecars_to_load:
+                path = self.projection.tile_destination
+                if sidecar == "":
+                    path = path / Path(key).with_suffix(".feather")
+                else:
+                    path = path / Path(key).with_suffix(f".{sidecar}.feather")
+                carfile = feather.read_table(path, memory_map=True)
                 for col in carfile.column_names:
-                    tb = tb.append_column(col, carfile[col])
+                    if col in self.fields:
+                        tb = tb.append_column(col, carfile[col])
             tbs.append(tb)
-        self._tb = pa.concat_tables(tbs)
 
-        return self._tb
+        self._tb = pa.concat_tables(tbs)
 
-    def _download_data(self, fields=None):
+    def _download_data(self, fields: Optional[List[str]] = None) -> List[Tuple[str, str]]:
         """
-        Downloads the feather tree for large sidecar columns.
+        Downloads the feather tree for user uploaded data.
+
+        fields:
+            A list of fields to download. If None, downloads all fields.
+
+        Returns:
+            List of downloaded columns
         """
-        logger.info("Downloading dataset")
+        logger.info("Downloading data")
         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
-        root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
-
-        all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        sidecars = None
-        if fields is None:
-            fields = self.dataset.dataset_fields
-        else:
-            for field in fields:
-                assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
 
-        sidecars = [
-            (field, sidecar)
-            for field, sidecar in self.projection._registered_sidecars()
-            if field[0] != "_" and field in fields
+        # Download specified or all sidecar fields + always download datum_id
+        data_columns_to_load = [
+            (str(field), str(sidecar))
+            for field, sidecar in self.projection._registered_columns
+            if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
         ]
 
-        for quad in tqdm(all_quads):
-            for field, encoded_colname in sidecars:
-                quad_str = os.path.join(*[str(q) for q in quad])
-                filename = quad_str + "." + encoded_colname + ".feather"
-                path = self.projection.tile_destination / Path(filename)
-                # WARNING: Potentially large data request here
-                download_feather(root + filename, path, headers=self.dataset.header, overwrite=False)
-        return sidecars
+        # TODO: less confusing progress bar
+        for sidecar in set([sidecar for _, sidecar in data_columns_to_load]):
+            self.projection._download_sidecar(sidecar)
+        return data_columns_to_load
 
     @property
     def df(self) -> pd.DataFrame:
@@ -826,7 +920,8 @@ class AtlasMapData:
         A pandas DataFrame associating each datapoint on your map to their metadata.
         Converting to pandas DataFrame may materialize a large amount of data into memory.
         """
-        return self._tb.to_pandas()
+        logger.warning("Converting to pandas dataframe. This may materialize a large amount of data into memory.")
+        return self.tb.to_pandas()
 
     @property
     def tb(self) -> pa.Table:
@@ -835,4 +930,9 @@ class AtlasMapData:
         This table is memmapped from the underlying files and is the most efficient way to
         access metadata information.
         """
+        if isinstance(self._tb, pa.Table):
+            return self._tb
+
+        columns = self._download_data(fields=self.fields)
+        self._load_data(columns)
         return self._tb
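Because the constructor now only validates `fields` and defers all I/O to `tb`, requesting a subset of fields skips downloading unrelated sidecars. A hedged sketch mirroring the `AtlasMapData(self.projection, fields=[time_field])` call seen earlier (identifier and field name invented; import paths assumed):

    from nomic import AtlasDataset
    from nomic.data_operations import AtlasMapData  # assumed import path

    projection = AtlasDataset("my-org/my-dataset").maps[0]  # hypothetical identifier
    data = AtlasMapData(projection, fields=["text"])  # validates field names only; no I/O yet
    table = data.tb  # first access downloads and loads just these columns (plus datum ids)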
@@ -4,28 +4,21 @@ import concurrent.futures
 import io
 import json
 import os
-import pickle
 import time
-import uuid
-from collections import defaultdict
 from contextlib import contextmanager
-from datetime import date, datetime
+from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union, overload
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
-import pandas as pd
 import pyarrow as pa
 import requests
 from loguru import logger
 from pandas import DataFrame
 from pyarrow import compute as pc
 from pyarrow import feather, ipc
-from pydantic import BaseModel, Field
 from tqdm import tqdm
 
-import nomic
-
 from .cli import refresh_bearer_token, validate_api_http_response
 from .data_inference import (
     NomicDuplicatesOptions,
@@ -433,6 +426,8 @@ class AtlasProjection:
         self._tile_data = None
         self._data = None
         self._schema = None
+        self._manifest_tb: Optional[pa.Table] = None
+        self._columns: List[Tuple[str, str]] = []
 
     @property
     def map_link(self):
@@ -590,130 +585,77 @@ class AtlasProjection:
         self._schema = ipc.read_schema(io.BytesIO(content))
         return self._schema
 
-    def _registered_sidecars(self) -> List[Tuple[str, str]]:
+    @property
+    def _registered_columns(self) -> List[Tuple[str, str]]:
         "Returns [(field_name, sidecar_name), ...]"
-        sidecars = []
+        if self._columns:
+            return self._columns
+        self._columns = []
         for field in self.schema:
             sidecar_name = json.loads(field.metadata.get(b"sidecar_name", b'""'))
-            if sidecar_name:
-                sidecars.append((field.name, sidecar_name))
-        return sidecars
+            if sidecar_name is not None:
+                self._columns.append((field.name, sidecar_name))
+        return self._columns
 
-    def _fetch_tiles(self, overwrite: bool = False):
+    @property
+    def _manifest(self) -> pa.Table:
         """
-        Downloads all web data for the projection to the specified directory and returns it as a memmapped arrow table.
-
-        Args:
-            overwrite: If True then overwrite web tile files.
-
-        Returns:
-            An Arrow table containing information for all data points in the index.
-        """
-        if self._tile_data is not None:
-            return self._tile_data
-        logger.info(f"Downloading files for projection {self.projection_id}")
-        self._download_large_feather(overwrite=overwrite)
-        tbs = []
-        root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True)
-        try:
-            sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()])
-        except KeyError:
-            sidecars = set([])
-        sidecars |= set(sidecar_name for (_, sidecar_name) in self._registered_sidecars())
-        for path in self._tiles_in_order():
-            tb = pa.feather.read_table(path, memory_map=True)  # type: ignore
-            for sidecar_file in sidecars:
-                carfile = pa.feather.read_table(  # type: ignore
-                    path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True
-                )
-                for col in carfile.column_names:
-                    tb = tb.append_column(col, carfile[col])
-            tbs.append(tb)
-        self._tile_data = pa.concat_tables(tbs)
-
-        return self._tile_data
-
-    @overload
-    def _tiles_in_order(self, *, coords_only: Literal[False] = ...) -> Iterator[Path]: ...
+        Returns the tile manifest for the projection.
+        Tile manifest is in quadtree order. All quadtree operations should
+        depend on tile manifest to ensure consistency.
+        """
+        if self._manifest_tb is not None:
+            return self._manifest_tb
 
-    @overload
-    def _tiles_in_order(self, *, coords_only: Literal[True]) -> Iterator[Tuple[int, int, int]]: ...
+        manifest_path = self.tile_destination / "manifest.feather"
+        manifest_url = (
+            self.dataset.atlas_api_path
+            + f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/manifest.feather"
+        )
 
-    @overload
-    def _tiles_in_order(self, *, coords_only: bool) -> Iterator[Any]: ...
+        download_feather(manifest_url, manifest_path, headers=self.dataset.header, overwrite=False)
+        self._manifest_tb = feather.read_table(manifest_path, memory_map=False)
+        return self._manifest_tb
 
-    def _tiles_in_order(self, *, coords_only: bool = False) -> Iterator[Any]:
+    def _get_sidecar_from_field(self, field: str) -> str:
         """
-        Returns:
-            A list of all tiles in the projection in a fixed order so that all
-            datasets are guaranteed to be aligned.
-        """
-
-        def children(z, x, y):
-            # This is the definition of a quadtree.
-            return [
-                (z + 1, x * 2, y * 2),
-                (z + 1, x * 2 + 1, y * 2),
-                (z + 1, x * 2, y * 2 + 1),
-                (z + 1, x * 2 + 1, y * 2 + 1),
-            ]
-
-        # start with the root
-        paths = [(0, 0, 0)]
-        # Pop off the front, extend the back (breadth first traversal)
-        while len(paths) > 0:
-            z, x, y = paths.pop(0)
-            path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix(".feather")
-            if path.exists():
-                if coords_only:
-                    yield (z, x, y)
-                else:
-                    yield path
-                paths.extend(children(z, x, y))  # pyright: ignore
+        Returns the sidecar name for a given field.
 
-    @property
-    def tile_destination(self):
-        return Path("~/.nomic/cache", self.id).expanduser()
+        Args:
+            field: the name of the field
+        """
+        for f, sidecar in self._registered_columns:
+            if field == f:
+                return sidecar
+        raise ValueError(f"Field {field} not found in registered columns.")
 
-    def _download_large_feather(self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True):
+    def _download_sidecar(self, sidecar_name, overwrite: bool = False) -> List[Path]:
         """
-        Downloads the feather tree.
+        Downloads sidecar files from the quadtree
         Args:
+            sidecar_name: the name of the sidecar file
             overwrite: if True then overwrite existing feather files.
 
         Returns:
-            A list containing all quadtiles downloads.
-        """
-        # TODO: change overwrite default to False once updating projection is removed.
-        quads = [f"0/0/0"]
-        self.tile_destination.mkdir(parents=True, exist_ok=True)
-        root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/"
-        all_quads = []
-        sidecars = None
-        registered_sidecars = set(sidecar_name for (_, sidecar_name) in self._registered_sidecars())
-        while len(quads) > 0:
-            rawquad = quads.pop(0)
-            quad = rawquad + ".feather"
-            all_quads.append(quad)
-            path = self.tile_destination / quad
-            schema = download_feather(root + quad, path, headers=self.dataset.header, overwrite=overwrite)
-
-            if sidecars is None and b"sidecars" in schema.metadata:
-                # Grab just the filenames
-                sidecars = set([v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()])
-            elif sidecars is None:
-                sidecars = set()
-            if not "." in rawquad:
-                for sidecar in sidecars | registered_sidecars:
-                    # The sidecar loses the feather suffix because it's supposed to be raw.
-                    quads.append(quad.replace(".feather", f".{sidecar}"))
-            if not schema.metadata or b"children" not in schema.metadata:
-                # Sidecars don't have children.
-                continue
-            kids = schema.metadata.get(b"children")
-            children = json.loads(kids)
-            quads.extend(children)
-        return all_quads
+            List of downloaded feather files.
+        """
+        downloaded_files = []
+        sidecar_suffix = "feather"
+        if sidecar_name != "":
+            sidecar_suffix = f"{sidecar_name}.feather"
+        for key in tqdm(self._manifest["key"].to_pylist()):
+            sidecar_path = self.tile_destination / f"{key}.{sidecar_suffix}"
+            sidecar_url = (
+                self.dataset.atlas_api_path
+                + f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/{key}.{sidecar_suffix}"
+            )
+            download_feather(sidecar_url, sidecar_path, headers=self.dataset.header, overwrite=overwrite)
+            downloaded_files.append(sidecar_path)
+        return downloaded_files
+
+    @property
+    def tile_destination(self):
+        return Path("~/.nomic/cache", self.id).expanduser()
 
     @property
     def datum_id_field(self):
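This hunk is the structural heart of the release: instead of walking the quadtree via per-tile schema metadata (`children`), the client fetches a single manifest.feather listing every tile key and derives all tile and sidecar downloads from it. A minimal sketch of that derivation (the manifest schema is assumed to be a table with a string `key` column holding keys like `0/0/0`):

    from pathlib import Path
    from pyarrow import feather

    def sidecar_files(manifest_path: Path, sidecar_name: str) -> list:
        # The manifest is a feather table whose "key" column lists quadtree
        # tile keys; every download path is derived from those keys.
        manifest = feather.read_table(manifest_path, memory_map=False)
        suffix = "feather" if sidecar_name == "" else f"{sidecar_name}.feather"
        return [f"{key}.{suffix}" for key in manifest["key"].to_pylist()]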
@@ -243,30 +243,48 @@ def get_object_size_in_bytes(obj):
 # Helpful function for downloading feather files
 # Best for small feather files
 def download_feather(
-    url: Union[str, Path], path: Path, headers: Optional[dict] = None, retries=3, overwrite=False
+    url: Union[str, Path], path: Path, headers: Optional[dict] = None, num_attempts=1, overwrite=False
 ) -> pa.Schema:
     """
     Download a feather file from a URL to a local path.
     Returns the schema of feather file if successful.
+
+    Parameters:
+        url (str): URL to download feather file from.
+        path (Path): Local path to save feather file to.
+        headers (dict): Optional headers to include in request.
+        num_attempts (int): Number of download attempts before raising an error.
+        overwrite (bool): Whether to overwrite existing file.
+    Returns:
+        Feather schema.
     """
-    assert retries > 0, "Retries must be greater than 0"
+    assert num_attempts > 0, "Num attempts must be greater than 0"
     download_attempt = 0
     download_success = False
     schema = None
-    while download_attempt < retries and not download_success:
+    while download_attempt < num_attempts and not download_success:
         download_attempt += 1
         if not path.exists() or overwrite:
-            data = requests.get(str(url), headers=headers)
-            readable = BytesIO(data.content)
-            readable.seek(0)
-            tb = pa.feather.read_table(readable, memory_map=False)  # type: ignore
-            path.parent.mkdir(parents=True, exist_ok=True)
-            pa.feather.write_feather(tb, path)  # type: ignore
-        try:
-            schema = ipc.open_file(path).schema
-            download_success = True
-        except pa.ArrowInvalid:
-            path.unlink(missing_ok=True)
+            # Attempt download
+            try:
+                data = requests.get(str(url), headers=headers)
+                readable = BytesIO(data.content)
+                readable.seek(0)
+                tb = pa.feather.read_table(readable, memory_map=False)  # type: ignore
+                schema = tb.schema
+                path.parent.mkdir(parents=True, exist_ok=True)
+                pa.feather.write_feather(tb, path)  # type: ignore
+                download_success = True
+            except pa.ArrowInvalid:
+                # failed try again
+                path.unlink(missing_ok=True)
+        else:
+            # Load existing file
+            try:
+                schema = ipc.open_file(path).schema
+                download_success = True
+            except pa.ArrowInvalid:
+                path.unlink(missing_ok=True)
     if not download_success or schema is None:
-        raise ValueError(f"Failed to download feather file from {url} after {retries} attempts.")
+        raise ValueError(f"Failed to download feather file from {url} after {num_attempts} attempts.")
     return schema
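Two behavioral notes on `download_feather`: the default dropped from three retries to a single attempt, and when `num_attempts > 1` a corrupt cached file is now unlinked and re-downloaded on the next loop iteration rather than failing outright. A hedged usage sketch (URL is a placeholder):

    from pathlib import Path
    from nomic.utils import download_feather  # assumed import path for this helper

    schema = download_feather(
        "https://example.invalid/quadtree/0/0/0.feather",  # placeholder URL
        Path("/tmp/tile.feather"),
        num_attempts=3,  # opt back in to retrying; the new default is 1
    )
    print(schema.names)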
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nomic
-Version: 3.0.32
+Version: 3.0.34
 Summary: The official Nomic python client.
 Home-page: https://github.com/nomic-ai/nomic
 Author: nomic.ai
@@ -8,7 +8,7 @@ description = "The official Nomic python client."
 
 setup(
     name="nomic",
-    version="3.0.32",
+    version="3.0.34",
     url="https://github.com/nomic-ai/nomic",
     description=description,
     long_description=description,
10 files without changes.