nomic 3.0.31.tar.gz → 3.0.32.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nomic
-Version: 3.0.31
+Version: 3.0.32
 Summary: The official Nomic python client.
 Home-page: https://github.com/nomic-ai/nomic
 Author: nomic.ai
@@ -232,7 +232,6 @@ def embed_image(
     region_name: str,
     model_name="nomic-embed-vision-v1",
 ) -> dict:
-
     embeddings = []

     max_workers = mp.cpu_count()
@@ -84,11 +84,11 @@ class NomicTopicOptions(BaseModel):

     Args:
         build_topic_model: If True, builds a topic model over your dataset's embeddings.
-        community_description_target_field: The dataset field/column that Atlas will use to assign a human-readable description to each topic.
+        topic_label_field: The dataset column (usually the column you embedded) that Atlas will use to assign a human-readable description to each topic.
     """

     build_topic_model: bool = True
-    community_description_target_field: Optional[str] = Field(default=None, alias="topic_label_field")
+    topic_label_field: Optional[str] = Field(default=None, alias="community_description_target_field")
     cluster_method: str = "fast"
     enforce_topic_hierarchy: bool = False

@@ -1,13 +1,9 @@
 import base64
-import concurrent
-import concurrent.futures
-import glob
 import io
 import json
 import os
 from collections import defaultdict
 from datetime import datetime
-from io import BytesIO
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, Tuple

@@ -17,7 +13,7 @@ import pyarrow as pa
 import requests
 from loguru import logger
 from pyarrow import compute as pc
-from pyarrow import feather, ipc
+from pyarrow import feather
 from tqdm import tqdm

 from .settings import EMBEDDING_PAGINATION_LIMIT
@@ -99,6 +95,7 @@ class AtlasMapTopics:
         self._hierarchy = None

         try:
+            logger.info("Downloading topics")
             self._tb: pa.Table = projection._fetch_tiles()
             topic_fields = [column for column in self._tb.column_names if column.startswith("_topic_depth_")]
             self.depth = len(topic_fields)
@@ -426,53 +423,41 @@ class AtlasMapEmbeddings:
         if self._latent is not None:
             return self._latent

-        root_embedding = self.projection.tile_destination / "0/0/0-0.embeddings.feather"
-        # Not the most complete check, hence the warning below.
-        if not root_embedding.exists():
-            self._download_latent()
+        downloaded_files_in_tile_order = self._download_latent()
+        assert len(downloaded_files_in_tile_order) > 0, "No embeddings found for this map."
         all_embeddings = []

-        for path in self.projection._tiles_in_order(coords_only=False):
-            # double with-suffix to remove '.embeddings.feather'
-            files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather")
+        for path in downloaded_files_in_tile_order:
             # Should there be more than 10, we need to sort by int values, not string values
-            sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1]))
-            if len(sortable) == 0:
-                raise FileNotFoundError(
-                    "Could not find any embeddings for tile {}".format(path)
-                    + " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'."
-                )
-            for file in sortable:
-                tb = feather.read_table(file, memory_map=True)
-                dims = tb["_embeddings"].type.list_size
-                all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
+            tb = feather.read_table(path, memory_map=True)
+            dims = tb["_embeddings"].type.list_size
+            all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
         return np.vstack(all_embeddings)

-    def _download_latent(self):
+    def _download_latent(self) -> List[Path]:
         """
-        Downloads the latent embeddings one file at a time.
+        Downloads the feather tree for embeddings.
+        Returns the paths to the downloaded embedding files.
         """
-        logger.warning("Downloading latent embeddings of all datapoints.")
-        limit = 10_000
-        route = self.projection.dataset.atlas_api_path + "/v1/project/data/get/embedding/paged"
-        last = None
+        # TODO: Is size of the embedding files (several hundreds of MBs) going to be a problem here?
+        self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
+        root_url = Path(
+            f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
+        )

-        with tqdm(total=self.dataset.total_datums // limit) as pbar:
-            while True:
-                params = {"projection_id": self.projection.id, "last_file": last, "page_size": limit}
-                r = requests.post(route, headers=self.projection.dataset.header, json=params)
-                if r.status_code == 204:
-                    # Download complete!
-                    break
-                fin = BytesIO(r.content)
-                tb = feather.read_table(fin, memory_map=True)
+        registered_sidecar_names = [sidecar[1] for sidecar in self.projection._registered_sidecars()]
+        assert "embeddings" in registered_sidecar_names, "Embeddings not found in sidecars."

-                tilename = tb.schema.metadata[b"tile"].decode("utf-8")
-                dest = (self.projection.tile_destination / tilename).with_suffix(".embeddings.feather")
-                dest.parent.mkdir(parents=True, exist_ok=True)
-                feather.write_feather(tb, dest)
-                last = tilename
-                pbar.update(1)
+        downloaded_files_in_tile_order = []
+        logger.info("Downloading latent embeddings...")
+        all_quads = list(self.projection._tiles_in_order())
+        for quad in tqdm(all_quads):
+            path = quad.with_suffix(".embeddings.feather")
+            # WARNING: Potentially large data request here
+            quadtree_loc = Path(*path.parts[-3:])
+            download_feather(root_url / quadtree_loc, path, headers=self.dataset.header, overwrite=False)
+            downloaded_files_in_tile_order.append(path)
+        return downloaded_files_in_tile_order

     def vector_search(
         self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5
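
For reference, the list_flatten/reshape idiom above turns an Arrow FixedSizeList column of embeddings into a 2-D numpy matrix. A self-contained illustration with toy data (not Atlas output):

    import pyarrow as pa
    import pyarrow.compute as pc

    dims = 4
    # Two 4-dimensional embeddings stored as a FixedSizeList column.
    col = pa.array([[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], type=pa.list_(pa.float32(), dims))
    tb = pa.table({"_embeddings": col})

    # Flatten to one long array, then reshape to (num_rows, dims).
    flat = pc.list_flatten(tb["_embeddings"]).to_numpy()
    matrix = flat.reshape(-1, tb["_embeddings"].type.list_size)
    assert matrix.shape == (2, dims)
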
@@ -694,6 +679,7 @@ class AtlasMapTags:
         """
         Downloads the feather tree for large sidecar columns.
         """
+        logger.info("Downloading tags")
         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
         root_url = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"

@@ -706,20 +692,7 @@ class AtlasMapTags:
             quad_str = os.path.join(*[str(q) for q in quad])
             filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
             path = self.projection.tile_destination / Path(filename)
-            download_attempt = 0
-            download_success = False
-            while download_attempt < 3 and not download_success:
-                download_attempt += 1
-                if not path.exists() or overwrite:
-                    download_feather(root_url + filename, path, headers=self.dataset.header)
-                try:
-                    ipc.open_file(path).schema
-                    download_success = True
-                except pa.ArrowInvalid:
-                    path.unlink(missing_ok=True)
-
-            if not download_success:
-                raise Exception(f"Failed to download tag {tag_name}.")
+            download_feather(root_url + filename, path, headers=self.dataset.header, overwrite=True)
             ordered_tag_paths.append(path)
         return ordered_tag_paths

@@ -791,7 +764,6 @@ class AtlasMapData:
         self.projection = projection
         self.dataset = projection.dataset
         self.id_field = self.projection.dataset.id_field
-        self.fields = fields
         try:
             # Run fetch_tiles first to guarantee existence of quad feather files
             self._basic_data: pa.Table = self.projection._fetch_tiles()
@@ -801,29 +773,15 @@ class AtlasMapData:
         except pa.lib.ArrowInvalid as e:  # type: ignore
             raise ValueError("Failed to fetch tiles for this map")

-    def _read_prefetched_tiles_with_sidecars(self, additional_sidecars):
+    def _read_prefetched_tiles_with_sidecars(self, sidecars):
         tbs = []
-        root = feather.read_table(self.projection.tile_destination / Path("0/0/0.feather"))  # type: ignore
-        try:
-            small_sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()])
-        except KeyError:
-            small_sidecars = set([])
         for path in self.projection._tiles_in_order():
             tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"])  # type: ignore
             for col in tb.column_names:
                 if col[0] == "_":
                     tb = tb.drop([col])
-            for sidecar_file in small_sidecars:
-                carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True)  # type: ignore
-                for col in carfile.column_names:
-                    tb = tb.append_column(col, carfile[col])
-            for big_sidecar in additional_sidecars:
-                fname = (
-                    base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8")
-                    if big_sidecar != "datum_id"
-                    else big_sidecar
-                )
-                carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True)  # type: ignore
+            for _, sidecar in sidecars:
+                carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar}.feather", memory_map=True)  # type: ignore
                 for col in carfile.column_names:
                     tb = tb.append_column(col, carfile[col])
             tbs.append(tb)
@@ -835,36 +793,31 @@ class AtlasMapData:
         """
         Downloads the feather tree for large sidecar columns.
         """
+        logger.info("Downloading dataset")
         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
         root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"

         all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        sidecars = fields
-        registered_sidecars = self.projection._registered_sidecars()
-        if sidecars is None:
-            sidecars = [
-                field
-                for field in self.dataset.dataset_fields
-                if field not in self._basic_data.column_names and field != "_embeddings"
-            ]
+        sidecars = None
+        if fields is None:
+            fields = self.dataset.dataset_fields
         else:
-            for field in sidecars:
+            for field in fields:
                 assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
-        encoded_sidecars = [base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") for sidecar in sidecars]
-        if any(sidecar == "datum_id" for (field, sidecar) in registered_sidecars):
-            sidecars.append("datum_id")
-            encoded_sidecars.append("datum_id")
+
+        sidecars = [
+            (field, sidecar)
+            for field, sidecar in self.projection._registered_sidecars()
+            if field[0] != "_" and field in fields
+        ]

         for quad in tqdm(all_quads):
-            for encoded_colname in encoded_sidecars:
+            for field, encoded_colname in sidecars:
                 quad_str = os.path.join(*[str(q) for q in quad])
                 filename = quad_str + "." + encoded_colname + ".feather"
                 path = self.projection.tile_destination / Path(filename)
-
-                if not os.path.exists(path):
-                    # WARNING: Potentially large data request here
-                    download_feather(root + filename, path, headers=self.dataset.header)
-
+                # WARNING: Potentially large data request here
+                download_feather(root + filename, path, headers=self.dataset.header, overwrite=False)
         return sidecars

     @property
@@ -36,7 +36,7 @@ from .data_inference import (
 )
 from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics
 from .settings import *
-from .utils import assert_valid_project_id, get_object_size_in_bytes
+from .utils import assert_valid_project_id, download_feather


 class AtlasUser:
@@ -611,6 +611,7 @@ class AtlasProjection:
         """
         if self._tile_data is not None:
             return self._tile_data
+        logger.info(f"Downloading files for projection {self.projection_id}")
         self._download_large_feather(overwrite=overwrite)
         tbs = []
         root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True)
@@ -634,8 +635,10 @@ class AtlasProjection:

     @overload
     def _tiles_in_order(self, *, coords_only: Literal[False] = ...) -> Iterator[Path]: ...
+
     @overload
     def _tiles_in_order(self, *, coords_only: Literal[True]) -> Iterator[Tuple[int, int, int]]: ...
+
     @overload
     def _tiles_in_order(self, *, coords_only: bool) -> Iterator[Any]: ...

@@ -693,27 +696,7 @@ class AtlasProjection:
             quad = rawquad + ".feather"
             all_quads.append(quad)
             path = self.tile_destination / quad
-
-            download_attempt = 0
-            download_success = False
-            schema = None
-            while download_attempt < 3 and not download_success:
-                download_attempt += 1
-                if not path.exists() or overwrite:
-                    data = requests.get(root + quad, headers=self.dataset.header)
-                    readable = io.BytesIO(data.content)
-                    readable.seek(0)
-                    tb = feather.read_table(readable, memory_map=True)
-                    path.parent.mkdir(parents=True, exist_ok=True)
-                    feather.write_feather(tb, path)
-                try:
-                    schema = ipc.open_file(path).schema
-                    download_success = True
-                except pa.ArrowInvalid:
-                    path.unlink(missing_ok=True)
-
-            if not download_success or schema is None:
-                raise Exception(f"Failed to download tiles. Aborting...")
+            schema = download_feather(root + quad, path, headers=self.dataset.header, overwrite=overwrite)

             if sidecars is None and b"sidecars" in schema.metadata:
                 # Grab just the filenames
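
The schema returned by download_feather matters here because the root tile advertises its sidecar files through Arrow schema metadata, the same JSON mapping the removed _read_prefetched_tiles_with_sidecars code parsed. A hypothetical read of that metadata from a downloaded tile (the path is a placeholder):

    import json

    from pyarrow import ipc

    schema = ipc.open_file("tiles/0/0/0.feather").schema
    if schema.metadata is not None and b"sidecars" in schema.metadata:
        # Metadata values are bytes; the payload maps field names to sidecar filename suffixes.
        sidecar_suffixes = list(json.loads(schema.metadata[b"sidecars"]).values())
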
@@ -1160,7 +1143,7 @@ class AtlasDataset(AtlasClass):

         build_template = {}
         if self.modality == "embedding":
-            if topic_model.community_description_target_field is None:
+            if topic_model.topic_label_field is None:
                 logger.warning(
                     "You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
                 )
@@ -1188,7 +1171,7 @@ class AtlasDataset(AtlasClass):
             "topic_model_hyperparameters": json.dumps(
                 {
                     "build_topic_model": topic_model.build_topic_model,
-                    "community_description_target_field": topic_model.community_description_target_field,
+                    "community_description_target_field": topic_model.topic_label_field,  # TODO change key to topic_label_field post v0.0.85
                     "cluster_method": topic_model.cluster_method,
                     "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
                 }
@@ -1253,7 +1236,7 @@ class AtlasDataset(AtlasClass):
             "topic_model_hyperparameters": json.dumps(
                 {
                     "build_topic_model": topic_model.build_topic_model,
-                    "community_description_target_field": indexed_field,
+                    "community_description_target_field": indexed_field,  # TODO change key to topic_label_field post v0.0.85
                     "cluster_method": topic_model.build_topic_model,
                     "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
                 }
@@ -4,11 +4,12 @@ import random
 import sys
 from io import BytesIO
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
 from uuid import UUID

 import pyarrow as pa
 import requests
+from pyarrow import ipc

 nouns = [
     "newton",
@@ -241,10 +242,31 @@ def get_object_size_in_bytes(obj):

 # Helpful function for downloading feather files
 # Best for small feather files
-def download_feather(url: str, path: Path, headers: Optional[dict] = None):
-    data = requests.get(url, headers=headers)
-    readable = BytesIO(data.content)
-    readable.seek(0)
-    tb = pa.feather.read_table(readable, memory_map=True)  # type: ignore
-    path.parent.mkdir(parents=True, exist_ok=True)
-    pa.feather.write_feather(tb, path)  # type: ignore
+def download_feather(
+    url: Union[str, Path], path: Path, headers: Optional[dict] = None, retries=3, overwrite=False
+) -> pa.Schema:
+    """
+    Download a feather file from a URL to a local path.
+    Returns the schema of the feather file if successful.
+    """
+    assert retries > 0, "Retries must be greater than 0"
+    download_attempt = 0
+    download_success = False
+    schema = None
+    while download_attempt < retries and not download_success:
+        download_attempt += 1
+        if not path.exists() or overwrite:
+            data = requests.get(str(url), headers=headers)
+            readable = BytesIO(data.content)
+            readable.seek(0)
+            tb = pa.feather.read_table(readable, memory_map=False)  # type: ignore
+            path.parent.mkdir(parents=True, exist_ok=True)
+            pa.feather.write_feather(tb, path)  # type: ignore
+        try:
+            schema = ipc.open_file(path).schema
+            download_success = True
+        except pa.ArrowInvalid:
+            path.unlink(missing_ok=True)
+    if not download_success or schema is None:
+        raise ValueError(f"Failed to download feather file from {url} after {retries} attempts.")
+    return schema
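
A hypothetical call to the reworked helper (URL, token, and local path are placeholders): with overwrite=False an existing file is not re-fetched, but it is still validated and its schema returned.

    from pathlib import Path

    schema = download_feather(
        "https://api-atlas.nomic.ai/v1/project/<dataset-id>/index/projection/<projection-id>/quadtree/0/0/0.feather",
        Path("tiles/0/0/0.feather"),
        headers={"Authorization": "Bearer <api-token>"},
        retries=3,
        overwrite=False,
    )
    print(schema.names)  # column names of the downloaded tile
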
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nomic
-Version: 3.0.31
+Version: 3.0.32
 Summary: The official Nomic python client.
 Home-page: https://github.com/nomic-ai/nomic
 Author: nomic.ai
@@ -20,7 +20,7 @@ sagemaker

 [dev]
 nomic[all]
-black
+black==24.3.0
 coverage
 pylint
 pytest
@@ -8,7 +8,7 @@ description = "The official Nomic python client."

 setup(
     name="nomic",
-    version="3.0.31",
+    version="3.0.32",
     url="https://github.com/nomic-ai/nomic",
     description=description,
     long_description=description,
@@ -43,7 +43,7 @@ setup(
     ],
     "dev": [
         "nomic[all]",
-        "black",
+        "black==24.3.0",
         "coverage",
         "pylint",
         "pytest",
8 files without changes
File without changes