nomic 3.0.31__tar.gz → 3.0.33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nomic
- Version: 3.0.31
+ Version: 3.0.33
  Summary: The official Nomic python client.
  Home-page: https://github.com/nomic-ai/nomic
  Author: nomic.ai
@@ -232,7 +232,6 @@ def embed_image(
      region_name: str,
      model_name="nomic-embed-vision-v1",
  ) -> dict:
-
      embeddings = []

      max_workers = mp.cpu_count()
@@ -84,11 +84,11 @@ class NomicTopicOptions(BaseModel):

      Args:
          build_topic_model: If True, builds a topic model over your dataset's embeddings.
-         community_description_target_field: The dataset field/column that Atlas will use to assign a human-readable description to each topic.
+         topic_label_field: The dataset column (usually the column you embedded) that Atlas will use to assign a human-readable description to each topic.
      """

      build_topic_model: bool = True
-     community_description_target_field: Optional[str] = Field(default=None, alias="topic_label_field")
+     topic_label_field: Optional[str] = Field(default=None, alias="community_description_target_field")
      cluster_method: str = "fast"
      enforce_topic_hierarchy: bool = False

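The swap above leans on pydantic field aliasing: `topic_label_field` becomes the public attribute while the old `community_description_target_field` key survives as an alias. A minimal, standalone sketch of that mechanism (pydantic v2 syntax; `TopicOptions` is a stand-in class, and whether nomic enables population by field name is not visible in this diff):

    # Stand-in model, not the nomic class; populate_by_name=True is an assumption
    # needed so both the field name and the legacy alias are accepted on input.
    from typing import Optional
    from pydantic import BaseModel, ConfigDict, Field

    class TopicOptions(BaseModel):
        model_config = ConfigDict(populate_by_name=True)
        topic_label_field: Optional[str] = Field(default=None, alias="community_description_target_field")

    print(TopicOptions(topic_label_field="text").topic_label_field)                   # text
    print(TopicOptions(community_description_target_field="text").topic_label_field)  # text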
@@ -1,15 +1,10 @@
  import base64
- import concurrent
- import concurrent.futures
- import glob
  import io
  import json
- import os
  from collections import defaultdict
  from datetime import datetime
- from io import BytesIO
  from pathlib import Path
- from typing import Dict, Iterable, List, Optional, Tuple
+ from typing import Dict, Iterable, List, Optional, Tuple, Union

  import numpy as np
  import pandas as pd
@@ -17,12 +12,9 @@ import pyarrow as pa
  import requests
  from loguru import logger
  from pyarrow import compute as pc
- from pyarrow import feather, ipc
+ from pyarrow import feather
  from tqdm import tqdm

- from .settings import EMBEDDING_PAGINATION_LIMIT
- from .utils import download_feather
-

  class AtlasMapDuplicates:
      """
@@ -34,22 +26,57 @@ class AtlasMapDuplicates:
      def __init__(self, projection: "AtlasProjection"):  # type: ignore
          self.projection = projection
          self.id_field = self.projection.dataset.id_field
-         try:
-             duplicate_fields = [
-                 field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field
-             ]
-             cluster_fields = [field for field in projection._fetch_tiles().column_names if "_cluster" in field]
-             assert len(duplicate_fields) > 0, "Duplicate detection has not yet been run on this map."
-             self.duplicate_field = duplicate_fields[0]
-             self.cluster_field = cluster_fields[0]
-             self._tb: pa.Table = projection._fetch_tiles().select(
-                 [self.id_field, self.duplicate_field, self.cluster_field]
+
+         duplicate_columns = [
+             (field, sidecar)
+             for field, sidecar in self.projection._registered_columns
+             if field.startswith("_duplicate_class")
+         ]
+         cluster_columns = [
+             (field, sidecar) for field, sidecar in self.projection._registered_columns if field.startswith("_cluster")
+         ]
+
+         assert len(duplicate_columns) > 0, "Duplicate detection has not yet been run on this map."
+
+         self._duplicate_column = duplicate_columns[0]
+         self._cluster_column = cluster_columns[0]
+         self._tb = None
+
+     def _load_duplicates(self):
+         """
+         Loads duplicates from the feather tree.
+         """
+         tbs = []
+         duplicate_sidecar = self._duplicate_column[1]
+         self.duplicate_field = self._duplicate_column[0].lstrip("_")
+         self.cluster_field = self._cluster_column[0].lstrip("_")
+         logger.info("Loading duplicates")
+         for key in tqdm(self.projection._manifest["key"].to_pylist()):
+             # Use datum id as root table
+             tb = feather.read_table(
+                 self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
              )
-         except pa.lib.ArrowInvalid as e:  # type: ignore
-             raise ValueError("Duplicate detection has not yet been run on this map.")
-         self.duplicate_field = self.duplicate_field.lstrip("_")
-         self.cluster_field = self.cluster_field.lstrip("_")
-         self._tb = self._tb.rename_columns([self.id_field, self.duplicate_field, self.cluster_field])
+             path = self.projection.tile_destination
+
+             if duplicate_sidecar == "":
+                 path = path / Path(key).with_suffix(".feather")
+             else:
+                 path = path / Path(key).with_suffix(f".{duplicate_sidecar}.feather")
+
+             duplicate_tb = feather.read_table(path, memory_map=True)
+             for field in (self._duplicate_column[0], self._cluster_column[0]):
+                 tb = tb.append_column(field, duplicate_tb[field])
+             tbs.append(tb)
+         self._tb = pa.concat_tables(tbs).rename_columns([self.id_field, self.duplicate_field, self.cluster_field])
+
+     def _download_duplicates(self):
+         """
+         Downloads the feather tree for duplicates.
+         """
+         logger.info("Downloading duplicates")
+         self.projection._download_sidecar("datum_id", overwrite=False)
+         assert self._cluster_column[1] == self._duplicate_column[1], "Cluster and duplicate should be in same sidecar"
+         self.projection._download_sidecar(self._duplicate_column[1], overwrite=False)

      @property
      def df(self) -> pd.DataFrame:
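The `_download_duplicates`/`_load_duplicates` pair above follows the per-tile pattern this release adopts throughout: the tile's `datum_id` table acts as the root, sidecar columns are appended onto it, and the per-tile tables are concatenated at the end. A minimal sketch of the assembly step for one tile (the key and sidecar values are hypothetical placeholders, not nomic code):

    from pathlib import Path
    from typing import List

    import pyarrow as pa
    from pyarrow import feather

    def load_tile(tile_destination: Path, key: str, sidecar: str, columns: List[str]) -> pa.Table:
        # The datum_id table is the root every other column hangs off.
        root = feather.read_table(
            tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
        )
        # An empty sidecar name means the column lives in the tile's base feather file.
        suffix = ".feather" if sidecar == "" else f".{sidecar}.feather"
        car = feather.read_table(tile_destination / Path(key).with_suffix(suffix), memory_map=True)
        for col in columns:
            root = root.append_column(col, car[col])
        return root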
@@ -65,6 +92,10 @@ class AtlasMapDuplicates:
          This table is memmapped from the underlying files and is the most efficient way to
          access duplicate information.
          """
+         if isinstance(self._tb, pa.Table):
+             return self._tb
+         self._download_duplicates()
+         self._load_duplicates()
          return self._tb

      def deletion_candidates(self) -> List[str]:
@@ -97,34 +128,69 @@ class AtlasMapTopics:
          self.id_field = self.projection.dataset.id_field
          self._metadata = None
          self._hierarchy = None
+         self._topic_columns = [
+             column for column in self.projection._registered_columns if column[0].startswith("_topic_depth_")
+         ]
+         assert len(self._topic_columns) > 0, "Topic modeling has not yet been run on this map."
+         self.depth = len(self._topic_columns)
+         self._tb = None
+
+     def _load_topics(self):
+         """
+         Loads topics from the feather tree.
+         """
+         integer_topics = False
+         # pd.Series to match pd typing
+         label_df: Optional[Union[pd.DataFrame, pd.Series]] = None
+         if "int" in self._topic_columns[0][0]:
+             integer_topics = True
+             label_df = self.metadata[["topic_id", "depth", "topic_short_description"]]
+         tbs = []
+         # Should just be one sidecar
+         topic_sidecar = set([sidecar for _, sidecar in self._topic_columns]).pop()
+         logger.info("Loading topics")
+         for key in tqdm(self.projection._manifest["key"].to_pylist()):
+             # Use datum id as root table
+             tb = feather.read_table(
+                 self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+             )
+             path = self.projection.tile_destination
+             if topic_sidecar == "":
+                 path = path / Path(key).with_suffix(".feather")
+             else:
+                 path = path / Path(key).with_suffix(f".{topic_sidecar}.feather")

-         try:
-             self._tb: pa.Table = projection._fetch_tiles()
-             topic_fields = [column for column in self._tb.column_names if column.startswith("_topic_depth_")]
-             self.depth = len(topic_fields)
-
-             # If using topic ids, fetch topic labels
-             if "int" in topic_fields[0]:
-                 new_topic_fields = []
-                 label_df = self.metadata[["topic_id", "depth", "topic_short_description"]]
-                 for d in range(1, self.depth + 1):
+             topic_tb = feather.read_table(path, memory_map=True)
+             # Do this in depth order
+             for d in range(1, self.depth + 1):
+                 column = f"_topic_depth_{d}"
+                 if integer_topics:
                      column = f"_topic_depth_{d}_int"
-                     topic_ids_to_label = self._tb[column].to_pandas().rename("topic_id")
+                     topic_ids_to_label = topic_tb[column].to_pandas().rename("topic_id")
+                     assert label_df is not None
                      topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge(
                          topic_ids_to_label, on="topic_id", how="right"
                      )
                      new_column = f"_topic_depth_{d}"
-                     self._tb = self._tb.append_column(
+                     tb = tb.append_column(
                          new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"])
                      )
-                 new_topic_fields.append(new_column)
-             topic_fields = new_topic_fields
+                 else:
+                     tb = tb.append_column(f"_topic_depth_1", topic_tb["_topic_depth_1"])
+             tbs.append(tb)

-             renamed_fields = [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
-             self._tb = self._tb.select([self.id_field] + topic_fields).rename_columns([self.id_field] + renamed_fields)
+         renamed_columns = [self.id_field] + [f"topic_depth_{i}" for i in range(1, self.depth + 1)]
+         self._tb = pa.concat_tables(tbs).rename_columns(renamed_columns)

-         except pa.lib.ArrowInvalid as e:  # type: ignore
-             raise ValueError("Topic modeling has not yet been run on this map.")
+     def _download_topics(self):
+         """
+         Downloads the feather tree for topics.
+         """
+         logger.info("Downloading topics")
+         self.projection._download_sidecar("datum_id", overwrite=False)
+         topic_sidecars = set([sidecar for _, sidecar in self._topic_columns])
+         assert len(topic_sidecars) == 1, "Multiple topic sidecars found."
+         self.projection._download_sidecar(topic_sidecars.pop(), overwrite=False)

      @property
      def df(self) -> pd.DataFrame:
@@ -140,6 +206,10 @@ class AtlasMapTopics:
          This table is memmapped from the underlying files and is the most efficient way to
          access topic information.
          """
+         if isinstance(self._tb, pa.Table):
+             return self._tb
+         self._download_topics()
+         self._load_topics()
          return self._tb

      @property
@@ -263,8 +333,8 @@ class AtlasMapTopics:
              A list of `{topic, count}` dictionaries, sorted from largest count to smallest count.
          """
          data = AtlasMapData(self.projection, fields=[time_field])
-         time_data = data._tb.select([self.id_field, time_field])
-         merged_tb = self._tb.join(time_data, self.id_field, join_type="inner").combine_chunks()
+         time_data = data.tb.select([self.id_field, time_field])
+         merged_tb = self.tb.join(time_data, self.id_field, join_type="inner").combine_chunks()

          del time_data  # free up memory
@@ -379,8 +449,8 @@ class AtlasMapEmbeddings:
      def __init__(self, projection: "AtlasProjection"):  # type: ignore
          self.projection = projection
          self.id_field = self.projection.dataset.id_field
-         self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, "x", "y"])
          self.dataset = projection.dataset
+         self._tb: pa.Table = None
          self._latent = None

      @property
@@ -401,6 +471,31 @@ class AtlasMapEmbeddings:

          Does not include high-dimensional embeddings.
          """
+         if isinstance(self._tb, pa.Table):
+             return self._tb
+
+         self._download_projected()
+
+         logger.info("Loading projected embeddings")
+
+         tbs = []
+         coord_sidecar = self.projection._get_sidecar_from_field("x")
+         for key in tqdm(self.projection._manifest["key"].to_pylist()):
+             # Use datum id as root table
+             tb = feather.read_table(
+                 self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+             )
+             path = self.projection.tile_destination
+             if coord_sidecar == "":
+                 path = path / Path(key).with_suffix(".feather")
+             else:
+                 path = path / Path(key).with_suffix(f".{coord_sidecar}.feather")
+             carfile = feather.read_table(path, memory_map=True)
+             for col in carfile.column_names:
+                 if col in ["x", "y"]:
+                     tb = tb.append_column(col, carfile[col])
+             tbs.append(tb)
+         self._tb = pa.concat_tables(tbs)
          return self._tb

      @property
@@ -426,53 +521,43 @@ class AtlasMapEmbeddings:
          if self._latent is not None:
              return self._latent

-         root_embedding = self.projection.tile_destination / "0/0/0-0.embeddings.feather"
-         # Not the most complete check, hence the warning below.
-         if not root_embedding.exists():
-             self._download_latent()
+         downloaded_files_in_tile_order = self._download_latent()
+         assert len(downloaded_files_in_tile_order) > 0, "No embeddings found for this map."
          all_embeddings = []
-
-         for path in self.projection._tiles_in_order(coords_only=False):
-             # double with-suffix to remove '.embeddings.feather'
-             files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather")
-             # Should there be more than 10, we need to sort by int values, not string values
-             sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1]))
-             if len(sortable) == 0:
-                 raise FileNotFoundError(
-                     "Could not find any embeddings for tile {}".format(path)
-                     + " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'."
-                 )
-             for file in sortable:
-                 tb = feather.read_table(file, memory_map=True)
-                 dims = tb["_embeddings"].type.list_size
-                 all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
+         logger.info("Loading latent embeddings")
+         for path in tqdm(downloaded_files_in_tile_order):
+             tb = feather.read_table(path, memory_map=True)
+             dims = tb["_embeddings"].type.list_size
+             all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
          return np.vstack(all_embeddings)

-     def _download_latent(self):
+     def _download_projected(self) -> List[Path]:
          """
-         Downloads the latent embeddings one file at a time.
+         Downloads the feather tree for projection coordinates.
          """
-         logger.warning("Downloading latent embeddings of all datapoints.")
-         limit = 10_000
-         route = self.projection.dataset.atlas_api_path + "/v1/project/data/get/embedding/paged"
-         last = None
+         logger.info("Downloading projected embeddings")
+         # Note that y coord should be in same sidecar
+         coord_sidecar = self.projection._get_sidecar_from_field("x")
+         self.projection._download_sidecar("datum_id", overwrite=False)
+         return self.projection._download_sidecar(coord_sidecar, overwrite=False)

-         with tqdm(total=self.dataset.total_datums // limit) as pbar:
-             while True:
-                 params = {"projection_id": self.projection.id, "last_file": last, "page_size": limit}
-                 r = requests.post(route, headers=self.projection.dataset.header, json=params)
-                 if r.status_code == 204:
-                     # Download complete!
-                     break
-                 fin = BytesIO(r.content)
-                 tb = feather.read_table(fin, memory_map=True)
+     def _download_latent(self) -> List[Path]:
+         """
+         Downloads the feather tree for embeddings.
+         Returns the paths to the downloaded embeddings.
+         """
+         # TODO: Is size of the embedding files (several hundreds of MBs) going to be a problem here?
+         logger.info("Downloading latent embeddings")
+         embedding_sidecar = None
+         for field, sidecar in self.projection._registered_columns:
+             # NOTE: may be _embeddings or _embedding
+             if field == "_embeddings":
+                 embedding_sidecar = sidecar
+                 break

-                 tilename = tb.schema.metadata[b"tile"].decode("utf-8")
-                 dest = (self.projection.tile_destination / tilename).with_suffix(".embeddings.feather")
-                 dest.parent.mkdir(parents=True, exist_ok=True)
-                 feather.write_feather(tb, dest)
-                 last = tilename
-                 pbar.update(1)
+         if embedding_sidecar is None:
+             raise ValueError("No embeddings found for this map.")
+         return self.projection._download_sidecar(embedding_sidecar, overwrite=False)

      def vector_search(
          self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5
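The latent loader above recovers a matrix from Arrow's fixed-size-list storage: `list_size` gives the embedding dimension, `list_flatten` yields one long flat array, and a reshape restores the rows. A self-contained toy illustration of that step (not nomic code):

    import pyarrow as pa
    import pyarrow.compute as pc

    tb = pa.table({"_embeddings": pa.array([[0.0, 1.0], [2.0, 3.0]], type=pa.list_(pa.float32(), 2))})
    dims = tb["_embeddings"].type.list_size                            # embedding dimension: 2
    mat = pc.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims)
    print(mat.shape)                                                   # (2, 2)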
@@ -586,12 +671,15 @@ class AtlasMapTags:
          self.projection = projection
          self.dataset = projection.dataset
          self.id_field = self.projection.dataset.id_field
-         # Pre-fetch tiles first upon initialization
-         self.projection._fetch_tiles(overwrite=False)
+         # Pre-fetch datum ids first upon initialization
+         try:
+             self.projection._download_sidecar("datum_id")
+         except Exception:
+             raise ValueError("Failed to fetch datum ids which is required to load tags.")
          self.auto_cleanup = auto_cleanup

      @property
-     def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame:
+     def df(self, overwrite: bool = False) -> pd.DataFrame:
          """
          Pandas DataFrame mapping each data point to its tags.
          """
@@ -602,16 +690,13 @@ class AtlasMapTags:
          for tag in tags:
              self._download_tag(tag["tag_name"], overwrite=overwrite)
          tbs = []
-         all_quads = list(self.projection._tiles_in_order(coords_only=True))
-         for quad in tqdm(all_quads):
-             quad_str = os.path.join(*[str(q) for q in quad])
-             datum_id_filename = quad_str + "." + "datum_id" + ".feather"
-             path = self.projection.tile_destination / Path(datum_id_filename)
-             tb = feather.read_table(path, memory_map=True)
+         logger.info("Loading tags")
+         for key in tqdm(self.projection._manifest["key"].to_pylist()):
+             datum_id_path = self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather")
+             tb = feather.read_table(datum_id_path, memory_map=True)
              for tag in tags:
                  tag_definition_id = tag["tag_definition_id"]
-                 tag_filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
-                 path = self.projection.tile_destination / Path(tag_filename)
+                 path = self.projection.tile_destination / Path(key).with_suffix(f"._tag.{tag_definition_id}.feather")
                  tag_tb = feather.read_table(path, memory_map=True)
                  bitmask = None
                  if "all_set" in tag_tb.column_names:
@@ -650,7 +735,7 @@ class AtlasMapTags:
                  keep_tags.append(tag)
          return keep_tags

-     def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False):
+     def get_datums_in_tag(self, tag_name: str, overwrite: bool = False):
          """
          Returns the datum ids in a given tag.

@@ -661,9 +746,9 @@ class AtlasMapTags:
          Returns:
              List of datum ids.
          """
-         ordered_tag_paths = self._download_tag(tag_name, overwrite=overwrite)
+         tag_paths = self._download_tag(tag_name, overwrite=overwrite)
          datum_ids = []
-         for path in ordered_tag_paths:
+         for path in tag_paths:
              tb = feather.read_table(path)
              last_coord = path.name.split(".")[0]
              tile_path = path.with_name(last_coord + ".datum_id.feather")
@@ -690,38 +775,14 @@ class AtlasMapTags:
              return tag
          raise ValueError(f"Tag {name} not found in projection {self.projection.id}.")

-     def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False):
+     def _download_tag(self, tag_name: str, overwrite: bool = False):
          """
          Downloads the feather tree for large sidecar columns.
          """
-         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
-         root_url = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
-
+         logger.info("Downloading tags")
          tag = self._get_tag_by_name(tag_name)
          tag_definition_id = tag["tag_definition_id"]
-
-         all_quads = list(self.projection._tiles_in_order(coords_only=True))
-         ordered_tag_paths = []
-         for quad in tqdm(all_quads):
-             quad_str = os.path.join(*[str(q) for q in quad])
-             filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
-             path = self.projection.tile_destination / Path(filename)
-             download_attempt = 0
-             download_success = False
-             while download_attempt < 3 and not download_success:
-                 download_attempt += 1
-                 if not path.exists() or overwrite:
-                     download_feather(root_url + filename, path, headers=self.dataset.header)
-                 try:
-                     ipc.open_file(path).schema
-                     download_success = True
-                 except pa.ArrowInvalid:
-                     path.unlink(missing_ok=True)
-
-             if not download_success:
-                 raise Exception(f"Failed to download tag {tag_name}.")
-             ordered_tag_paths.append(path)
-         return ordered_tag_paths
+         return self.projection._download_sidecar(f"_tag.{tag_definition_id}", overwrite=overwrite)

      def _remove_outdated_tag_files(self, tag_definition_ids: List[str]):
          """
@@ -732,14 +793,12 @@ class AtlasMapTags:
              tag_definition_ids: A list of tag definition ids to keep.
          """
          # NOTE: This currently only gets triggered on `df` property
-         all_quads = list(self.projection._tiles_in_order(coords_only=True))
-         for quad in tqdm(all_quads):
-             quad_str = os.path.join(*[str(q) for q in quad])
-             tile = self.projection.tile_destination / Path(quad_str)
+         for key in self.projection._manifest["key"].to_pylist():
+             tile = self.projection.tile_destination / Path(key)
              tile_dir = tile.parent
              if tile_dir.exists():
-                 tagged_files = tile_dir.glob("*_tag*")
-                 for file in tagged_files:
+                 tag_files = tile_dir.glob("*_tag*")
+                 for file in tag_files:
                      tag_definition_id = file.name.split(".")[-2]
                      if tag_definition_id in tag_definition_ids:
                          try:
@@ -791,81 +850,69 @@ class AtlasMapData:
          self.projection = projection
          self.dataset = projection.dataset
          self.id_field = self.projection.dataset.id_field
-         self.fields = fields
-         try:
-             # Run fetch_tiles first to guarantee existence of quad feather files
-             self._basic_data: pa.Table = self.projection._fetch_tiles()
-             sidecars = self._download_data(fields=fields)
-             self._tb = self._read_prefetched_tiles_with_sidecars(sidecars)
+         if fields is None:
+             # TODO: fall back on something more reliable here
+             self.fields = self.dataset.dataset_fields
+         else:
+             for field in fields:
+                 assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
+             self.fields = fields
+         self._tb = None

-         except pa.lib.ArrowInvalid as e:  # type: ignore
-             raise ValueError("Failed to fetch tiles for this map")
+     def _load_data(self, data_columns: List[Tuple[str, str]]):
+         """
+         Loads data from a list of data columns (field and sidecar name tuples).

-     def _read_prefetched_tiles_with_sidecars(self, additional_sidecars):
+         Args:
+             data_columns: A list of tuples containing field name and sidecar name.
+         """
          tbs = []
-         root = feather.read_table(self.projection.tile_destination / Path("0/0/0.feather"))  # type: ignore
-         try:
-             small_sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()])
-         except KeyError:
-             small_sidecars = set([])
-         for path in self.projection._tiles_in_order():
-             tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"])  # type: ignore
-             for col in tb.column_names:
-                 if col[0] == "_":
-                     tb = tb.drop([col])
-             for sidecar_file in small_sidecars:
-                 carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True)  # type: ignore
-                 for col in carfile.column_names:
-                     tb = tb.append_column(col, carfile[col])
-             for big_sidecar in additional_sidecars:
-                 fname = (
-                     base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8")
-                     if big_sidecar != "datum_id"
-                     else big_sidecar
-                 )
-                 carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True)  # type: ignore
+
+         sidecars_to_load = set([sidecar for _, sidecar in data_columns if sidecar != "datum_id"])
+         logger.info("Loading data")
+         for key in tqdm(self.projection._manifest["key"].to_pylist()):
+             # Use datum id as root table
+             tb = feather.read_table(
+                 self.projection.tile_destination / Path(key).with_suffix(".datum_id.feather"), memory_map=True
+             )
+             for sidecar in sidecars_to_load:
+                 path = self.projection.tile_destination
+                 if sidecar == "":
+                     path = path / Path(key).with_suffix(".feather")
+                 else:
+                     path = path / Path(key).with_suffix(f".{sidecar}.feather")
+                 carfile = feather.read_table(path, memory_map=True)
                  for col in carfile.column_names:
-                     tb = tb.append_column(col, carfile[col])
+                     if col in self.fields:
+                         tb = tb.append_column(col, carfile[col])
              tbs.append(tb)
-         self._tb = pa.concat_tables(tbs)

-         return self._tb
+         self._tb = pa.concat_tables(tbs)

-     def _download_data(self, fields=None):
+     def _download_data(self, fields: Optional[List[str]] = None) -> List[Tuple[str, str]]:
          """
-         Downloads the feather tree for large sidecar columns.
+         Downloads the feather tree for user uploaded data.
+
+         fields:
+             A list of fields to download. If None, downloads all fields.
+
+         Returns:
+             List of downloaded columns
          """
+         logger.info("Downloading data")
          self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
-         root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
-
-         all_quads = list(self.projection._tiles_in_order(coords_only=True))
-         sidecars = fields
-         registered_sidecars = self.projection._registered_sidecars()
-         if sidecars is None:
-             sidecars = [
-                 field
-                 for field in self.dataset.dataset_fields
-                 if field not in self._basic_data.column_names and field != "_embeddings"
-             ]
-         else:
-             for field in sidecars:
-                 assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
-         encoded_sidecars = [base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") for sidecar in sidecars]
-         if any(sidecar == "datum_id" for (field, sidecar) in registered_sidecars):
-             sidecars.append("datum_id")
-             encoded_sidecars.append("datum_id")
-
-         for quad in tqdm(all_quads):
-             for encoded_colname in encoded_sidecars:
-                 quad_str = os.path.join(*[str(q) for q in quad])
-                 filename = quad_str + "." + encoded_colname + ".feather"
-                 path = self.projection.tile_destination / Path(filename)

-                 if not os.path.exists(path):
-                     # WARNING: Potentially large data request here
-                     download_feather(root + filename, path, headers=self.dataset.header)
+         # Download specified or all sidecar fields + always download datum_id
+         data_columns_to_load = [
+             (str(field), str(sidecar))
+             for field, sidecar in self.projection._registered_columns
+             if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
+         ]

-         return sidecars
+         # TODO: less confusing progress bar
+         for sidecar in set([sidecar for _, sidecar in data_columns_to_load]):
+             self.projection._download_sidecar(sidecar)
+         return data_columns_to_load

      @property
      def df(self) -> pd.DataFrame:
@@ -873,7 +920,8 @@ class AtlasMapData:
          A pandas DataFrame associating each datapoint on your map to their metadata.
          Converting to pandas DataFrame may materialize a large amount of data into memory.
          """
-         return self._tb.to_pandas()
+         logger.warning("Converting to pandas dataframe. This may materialize a large amount of data into memory.")
+         return self.tb.to_pandas()

      @property
      def tb(self) -> pa.Table:
@@ -882,4 +930,9 @@ class AtlasMapData:
          This table is memmapped from the underlying files and is the most efficient way to
          access metadata information.
          """
+         if isinstance(self._tb, pa.Table):
+             return self._tb
+
+         columns = self._download_data(fields=self.fields)
+         self._load_data(columns)
          return self._tb
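The selection rule in `_download_data` above keeps user columns (field names not starting with an underscore) that were requested, plus anything registered under the `datum_id` sidecar. A quick illustration of the filter on hypothetical registrations:

    # Hypothetical (field, sidecar) registrations; only the filter expression mirrors the diff.
    registered = [("text", ""), ("_embeddings", "embeddings"), ("id_", "datum_id")]
    fields = ["text"]

    selected = [
        (field, sidecar)
        for field, sidecar in registered
        if field[0] != "_" and ((field in fields) or sidecar == "datum_id")
    ]
    print(selected)  # [('text', ''), ('id_', 'datum_id')]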
@@ -4,28 +4,21 @@ import concurrent.futures
  import io
  import json
  import os
- import pickle
  import time
- import uuid
- from collections import defaultdict
  from contextlib import contextmanager
- from datetime import date, datetime
+ from datetime import datetime
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union, overload
+ from typing import Dict, List, Optional, Tuple, Union

  import numpy as np
- import pandas as pd
  import pyarrow as pa
  import requests
  from loguru import logger
  from pandas import DataFrame
  from pyarrow import compute as pc
  from pyarrow import feather, ipc
- from pydantic import BaseModel, Field
  from tqdm import tqdm

- import nomic
-
  from .cli import refresh_bearer_token, validate_api_http_response
  from .data_inference import (
      NomicDuplicatesOptions,
@@ -36,7 +29,7 @@ from .data_inference import (
  )
  from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics
  from .settings import *
- from .utils import assert_valid_project_id, get_object_size_in_bytes
+ from .utils import assert_valid_project_id, download_feather


  class AtlasUser:
@@ -433,6 +426,8 @@ class AtlasProjection:
          self._tile_data = None
          self._data = None
          self._schema = None
+         self._manifest_tb: Optional[pa.Table] = None
+         self._columns: List[Tuple[str, str]] = []

      @property
      def map_link(self):
@@ -590,147 +585,77 @@ class AtlasProjection:
          self._schema = ipc.read_schema(io.BytesIO(content))
          return self._schema

-     def _registered_sidecars(self) -> List[Tuple[str, str]]:
+     @property
+     def _registered_columns(self) -> List[Tuple[str, str]]:
          "Returns [(field_name, sidecar_name), ...]"
-         sidecars = []
+         if self._columns:
+             return self._columns
+         self._columns = []
          for field in self.schema:
              sidecar_name = json.loads(field.metadata.get(b"sidecar_name", b'""'))
-             if sidecar_name:
-                 sidecars.append((field.name, sidecar_name))
-         return sidecars
+             if sidecar_name is not None:
+                 self._columns.append((field.name, sidecar_name))
+         return self._columns

-     def _fetch_tiles(self, overwrite: bool = False):
+     @property
+     def _manifest(self) -> pa.Table:
          """
-         Downloads all web data for the projection to the specified directory and returns it as a memmapped arrow table.
-
-         Args:
-             overwrite: If True then overwrite web tile files.
-
-         Returns:
-             An Arrow table containing information for all data points in the index.
+         Returns the tile manifest for the projection.
+         Tile manifest is in quadtree order. All quadtree operations should
+         depend on tile manifest to ensure consistency.
          """
-         if self._tile_data is not None:
-             return self._tile_data
-         self._download_large_feather(overwrite=overwrite)
-         tbs = []
-         root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True)
-         try:
-             sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()])
-         except KeyError:
-             sidecars = set([])
-         sidecars |= set(sidecar_name for (_, sidecar_name) in self._registered_sidecars())
-         for path in self._tiles_in_order():
-             tb = pa.feather.read_table(path, memory_map=True)  # type: ignore
-             for sidecar_file in sidecars:
-                 carfile = pa.feather.read_table(  # type: ignore
-                     path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True
-                 )
-                 for col in carfile.column_names:
-                     tb = tb.append_column(col, carfile[col])
-             tbs.append(tb)
-         self._tile_data = pa.concat_tables(tbs)
+         if self._manifest_tb is not None:
+             return self._manifest_tb

-         return self._tile_data
+         manifest_path = self.tile_destination / "manifest.feather"
+         manifest_url = (
+             self.dataset.atlas_api_path
+             + f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/manifest.feather"
+         )

-     @overload
-     def _tiles_in_order(self, *, coords_only: Literal[False] = ...) -> Iterator[Path]: ...
-     @overload
-     def _tiles_in_order(self, *, coords_only: Literal[True]) -> Iterator[Tuple[int, int, int]]: ...
-     @overload
-     def _tiles_in_order(self, *, coords_only: bool) -> Iterator[Any]: ...
+         download_feather(manifest_url, manifest_path, headers=self.dataset.header, overwrite=False)
+         self._manifest_tb = feather.read_table(manifest_path, memory_map=False)
+         return self._manifest_tb

-     def _tiles_in_order(self, *, coords_only: bool = False) -> Iterator[Any]:
+     def _get_sidecar_from_field(self, field: str) -> str:
          """
-         Returns:
-             A list of all tiles in the projection in a fixed order so that all
-             datasets are guaranteed to be aligned.
-         """
-
-         def children(z, x, y):
-             # This is the definition of a quadtree.
-             return [
-                 (z + 1, x * 2, y * 2),
-                 (z + 1, x * 2 + 1, y * 2),
-                 (z + 1, x * 2, y * 2 + 1),
-                 (z + 1, x * 2 + 1, y * 2 + 1),
-             ]
-
-         # start with the root
-         paths = [(0, 0, 0)]
-         # Pop off the front, extend the back (breadth first traversal)
-         while len(paths) > 0:
-             z, x, y = paths.pop(0)
-             path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix(".feather")
-             if path.exists():
-                 if coords_only:
-                     yield (z, x, y)
-                 else:
-                     yield path
-                 paths.extend(children(z, x, y))  # pyright: ignore
+         Returns the sidecar name for a given field.

-     @property
-     def tile_destination(self):
-         return Path("~/.nomic/cache", self.id).expanduser()
+         Args:
+             field: the name of the field
+         """
+         for f, sidecar in self._registered_columns:
+             if field == f:
+                 return sidecar
+         raise ValueError(f"Field {field} not found in registered columns.")

-     def _download_large_feather(self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True):
+     def _download_sidecar(self, sidecar_name, overwrite: bool = False) -> List[Path]:
          """
-         Downloads the feather tree.
+         Downloads sidecar files from the quadtree
          Args:
+             sidecar_name: the name of the sidecar file
              overwrite: if True then overwrite existing feather files.

          Returns:
-             A list containing all quadtiles downloads.
-         """
-         # TODO: change overwrite default to False once updating projection is removed.
-         quads = [f"0/0/0"]
-         self.tile_destination.mkdir(parents=True, exist_ok=True)
-         root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/"
-         all_quads = []
-         sidecars = None
-         registered_sidecars = set(sidecar_name for (_, sidecar_name) in self._registered_sidecars())
-         while len(quads) > 0:
-             rawquad = quads.pop(0)
-             quad = rawquad + ".feather"
-             all_quads.append(quad)
-             path = self.tile_destination / quad
-
-             download_attempt = 0
-             download_success = False
-             schema = None
-             while download_attempt < 3 and not download_success:
-                 download_attempt += 1
-                 if not path.exists() or overwrite:
-                     data = requests.get(root + quad, headers=self.dataset.header)
-                     readable = io.BytesIO(data.content)
-                     readable.seek(0)
-                     tb = feather.read_table(readable, memory_map=True)
-                     path.parent.mkdir(parents=True, exist_ok=True)
-                     feather.write_feather(tb, path)
-                 try:
-                     schema = ipc.open_file(path).schema
-                     download_success = True
-                 except pa.ArrowInvalid:
-                     path.unlink(missing_ok=True)
-
-             if not download_success or schema is None:
-                 raise Exception(f"Failed to download tiles. Aborting...")
-
-             if sidecars is None and b"sidecars" in schema.metadata:
-                 # Grab just the filenames
-                 sidecars = set([v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()])
-             elif sidecars is None:
-                 sidecars = set()
-             if not "." in rawquad:
-                 for sidecar in sidecars | registered_sidecars:
-                     # The sidecar loses the feather suffix because it's supposed to be raw.
-                     quads.append(quad.replace(".feather", f".{sidecar}"))
-             if not schema.metadata or b"children" not in schema.metadata:
-                 # Sidecars don't have children.
-                 continue
-             kids = schema.metadata.get(b"children")
-             children = json.loads(kids)
-             quads.extend(children)
-         return all_quads
+             List of downloaded feather files.
+         """
+         downloaded_files = []
+         sidecar_suffix = "feather"
+         if sidecar_name != "":
+             sidecar_suffix = f"{sidecar_name}.feather"
+         for key in tqdm(self._manifest["key"].to_pylist()):
+             sidecar_path = self.tile_destination / f"{key}.{sidecar_suffix}"
+             sidecar_url = (
+                 self.dataset.atlas_api_path
+                 + f"/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/{key}.{sidecar_suffix}"
+             )
+             download_feather(sidecar_url, sidecar_path, headers=self.dataset.header, overwrite=overwrite)
+             downloaded_files.append(sidecar_path)
+         return downloaded_files
+
+     @property
+     def tile_destination(self):
+         return Path("~/.nomic/cache", self.id).expanduser()

      @property
      def datum_id_field(self):
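With the manifest in place, quadtree traversal reduces to iterating manifest keys and deriving file names, replacing the old breadth-first discovery of children tiles. A small sketch of the path scheme `_download_sidecar` uses (the key and sidecar names are placeholders):

    # Sketch of the tile-path scheme above; "0/0/0" and "datum_id" are placeholders.
    from pathlib import Path

    def sidecar_path(tile_destination: Path, key: str, sidecar_name: str) -> Path:
        # The base tile has no sidecar infix; named sidecars get ".<name>.feather".
        suffix = "feather" if sidecar_name == "" else f"{sidecar_name}.feather"
        return tile_destination / f"{key}.{suffix}"

    print(sidecar_path(Path("~/.nomic/cache/projection-id").expanduser(), "0/0/0", "datum_id"))
    # e.g. /home/user/.nomic/cache/projection-id/0/0/0.datum_id.feather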
@@ -1160,7 +1085,7 @@ class AtlasDataset(AtlasClass):

          build_template = {}
          if self.modality == "embedding":
-             if topic_model.community_description_target_field is None:
+             if topic_model.topic_label_field is None:
                  logger.warning(
                      "You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
                  )
@@ -1188,7 +1113,7 @@ class AtlasDataset(AtlasClass):
                  "topic_model_hyperparameters": json.dumps(
                      {
                          "build_topic_model": topic_model.build_topic_model,
-                         "community_description_target_field": topic_model.community_description_target_field,
+                         "community_description_target_field": topic_model.topic_label_field,  # TODO change key to topic_label_field post v0.0.85
                          "cluster_method": topic_model.cluster_method,
                          "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
                      }
@@ -1253,7 +1178,7 @@ class AtlasDataset(AtlasClass):
                  "topic_model_hyperparameters": json.dumps(
                      {
                          "build_topic_model": topic_model.build_topic_model,
-                         "community_description_target_field": indexed_field,
+                         "community_description_target_field": indexed_field,  # TODO change key to topic_label_field post v0.0.85
                          "cluster_method": topic_model.build_topic_model,
                          "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
                      }
@@ -4,11 +4,12 @@ import random
  import sys
  from io import BytesIO
  from pathlib import Path
- from typing import Optional
+ from typing import Optional, Union
  from uuid import UUID

  import pyarrow as pa
  import requests
+ from pyarrow import ipc

  nouns = [
      "newton",
@@ -241,10 +242,49 @@ def get_object_size_in_bytes(obj):

  # Helpful function for downloading feather files
  # Best for small feather files
- def download_feather(url: str, path: Path, headers: Optional[dict] = None):
-     data = requests.get(url, headers=headers)
-     readable = BytesIO(data.content)
-     readable.seek(0)
-     tb = pa.feather.read_table(readable, memory_map=True)  # type: ignore
-     path.parent.mkdir(parents=True, exist_ok=True)
-     pa.feather.write_feather(tb, path)  # type: ignore
+ def download_feather(
+     url: Union[str, Path], path: Path, headers: Optional[dict] = None, num_attempts=1, overwrite=False
+ ) -> pa.Schema:
+     """
+     Download a feather file from a URL to a local path.
+     Returns the schema of the feather file if successful.
+
+     Parameters:
+         url (str): URL to download feather file from.
+         path (Path): Local path to save feather file to.
+         headers (dict): Optional headers to include in request.
+         num_attempts (int): Number of download attempts before raising an error.
+         overwrite (bool): Whether to overwrite existing file.
+     Returns:
+         Feather schema.
+     """
+     assert num_attempts > 0, "Num attempts must be greater than 0"
+     download_attempt = 0
+     download_success = False
+     schema = None
+     while download_attempt < num_attempts and not download_success:
+         download_attempt += 1
+         if not path.exists() or overwrite:
+             # Attempt download
+             try:
+                 data = requests.get(str(url), headers=headers)
+                 readable = BytesIO(data.content)
+                 readable.seek(0)
+                 tb = pa.feather.read_table(readable, memory_map=False)  # type: ignore
+                 schema = tb.schema
+                 path.parent.mkdir(parents=True, exist_ok=True)
+                 pa.feather.write_feather(tb, path)  # type: ignore
+                 download_success = True
+             except pa.ArrowInvalid:
+                 # failed, try again
+                 path.unlink(missing_ok=True)
+         else:
+             # Load existing file
+             try:
+                 schema = ipc.open_file(path).schema
+                 download_success = True
+             except pa.ArrowInvalid:
+                 path.unlink(missing_ok=True)
+     if not download_success or schema is None:
+         raise ValueError(f"Failed to download feather file from {url} after {num_attempts} attempts.")
+     return schema
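Going by the new signature above, a hedged usage sketch (the URL and cache path are hypothetical; `num_attempts` and `overwrite` are the two new knobs, and the module path is assumed from the `.utils` imports elsewhere in this diff):

    from pathlib import Path

    from nomic.utils import download_feather  # module path assumed

    schema = download_feather(
        "https://api.example.com/quadtree/0/0/0.datum_id.feather",  # placeholder URL
        Path("~/.nomic/cache/example/0/0/0.datum_id.feather").expanduser(),
        num_attempts=3,   # retry corrupted downloads up to three times
        overwrite=False,  # a valid cached file is reused instead of re-downloaded
    )
    print(schema.names)  # column names from the feather schema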
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nomic
- Version: 3.0.31
+ Version: 3.0.33
  Summary: The official Nomic python client.
  Home-page: https://github.com/nomic-ai/nomic
  Author: nomic.ai
@@ -20,7 +20,7 @@ sagemaker

  [dev]
  nomic[all]
- black
+ black==24.3.0
  coverage
  pylint
  pytest
@@ -8,7 +8,7 @@ description = "The official Nomic python client."

  setup(
      name="nomic",
-     version="3.0.31",
+     version="3.0.33",
      url="https://github.com/nomic-ai/nomic",
      description=description,
      long_description=description,
@@ -43,7 +43,7 @@ setup(
      ],
      "dev": [
          "nomic[all]",
-         "black",
+         "black==24.3.0",
          "coverage",
          "pylint",
          "pytest",