nomic 3.0.30.tar.gz → 3.0.32.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nomic-3.0.30 → nomic-3.0.32}/PKG-INFO +1 -1
- {nomic-3.0.30 → nomic-3.0.32}/nomic/aws/sagemaker.py +0 -1
- {nomic-3.0.30 → nomic-3.0.32}/nomic/data_inference.py +2 -2
- {nomic-3.0.30 → nomic-3.0.32}/nomic/data_operations.py +47 -94
- {nomic-3.0.30 → nomic-3.0.32}/nomic/dataset.py +8 -25
- {nomic-3.0.30 → nomic-3.0.32}/nomic/embed.py +43 -12
- {nomic-3.0.30 → nomic-3.0.32}/nomic/utils.py +30 -8
- {nomic-3.0.30 → nomic-3.0.32}/nomic.egg-info/PKG-INFO +1 -1
- {nomic-3.0.30 → nomic-3.0.32}/nomic.egg-info/requires.txt +1 -1
- {nomic-3.0.30 → nomic-3.0.32}/setup.py +2 -2
- {nomic-3.0.30 → nomic-3.0.32}/README.md +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic/__init__.py +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic/atlas.py +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic/aws/__init__.py +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic/cli.py +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic/pl_callbacks/__init__.py +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic/pl_callbacks/pl_callback.py +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic/settings.py +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic.egg-info/SOURCES.txt +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic.egg-info/dependency_links.txt +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic.egg-info/entry_points.txt +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/nomic.egg-info/top_level.txt +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/pyproject.toml +0 -0
- {nomic-3.0.30 → nomic-3.0.32}/setup.cfg +0 -0
--- a/nomic/data_inference.py
+++ b/nomic/data_inference.py
@@ -84,11 +84,11 @@ class NomicTopicOptions(BaseModel):
 
     Args:
         build_topic_model: If True, builds a topic model over your dataset's embeddings.
-        …
+        topic_label_field: The dataset column (usually the column you embedded) that Atlas will use to assign a human-readable description to each topic.
     """
 
     build_topic_model: bool = True
-    …
+    topic_label_field: Optional[str] = Field(default=None, alias="community_description_target_field")
     cluster_method: str = "fast"
     enforce_topic_hierarchy: bool = False
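The rename keeps the old wire name as a pydantic alias, so payloads and callers that use `community_description_target_field` continue to work. A minimal usage sketch (the column name is hypothetical; construction goes through the alias, which pydantic accepts by default):

```python
from nomic.data_inference import NomicTopicOptions

# "text" is a hypothetical embedded column in your dataset.
options = NomicTopicOptions(build_topic_model=True, community_description_target_field="text")
assert options.topic_label_field == "text"
```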
--- a/nomic/data_operations.py
+++ b/nomic/data_operations.py
@@ -1,13 +1,9 @@
 import base64
-import concurrent
-import concurrent.futures
-import glob
 import io
 import json
 import os
 from collections import defaultdict
 from datetime import datetime
-from io import BytesIO
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, Tuple
 
@@ -17,7 +13,7 @@ import pyarrow as pa
 import requests
 from loguru import logger
 from pyarrow import compute as pc
-from pyarrow import feather, ipc
+from pyarrow import feather
 from tqdm import tqdm
 
 from .settings import EMBEDDING_PAGINATION_LIMIT
@@ -99,6 +95,7 @@ class AtlasMapTopics:
         self._hierarchy = None
 
         try:
+            logger.info("Downloading topics")
             self._tb: pa.Table = projection._fetch_tiles()
             topic_fields = [column for column in self._tb.column_names if column.startswith("_topic_depth_")]
             self.depth = len(topic_fields)
@@ -426,53 +423,41 @@ class AtlasMapEmbeddings:
         if self._latent is not None:
             return self._latent
 
-        …
-        if not root_embedding.exists():
-            self._download_latent()
+        downloaded_files_in_tile_order = self._download_latent()
+        assert len(downloaded_files_in_tile_order) > 0, "No embeddings found for this map."
         all_embeddings = []
 
-        for path in …
-            # double with-suffix to remove '.embeddings.feather'
-            files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather")
+        for path in downloaded_files_in_tile_order:
             # Should there be more than 10, we need to sort by int values, not string values
-            …
-                "Could not find any embeddings for tile {}".format(path)
-                + " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'."
-            )
-            for file in sortable:
-                tb = feather.read_table(file, memory_map=True)
-                dims = tb["_embeddings"].type.list_size
-                all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
+            tb = feather.read_table(path, memory_map=True)
+            dims = tb["_embeddings"].type.list_size
+            all_embeddings.append(pa.compute.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims))  # type: ignore
         return np.vstack(all_embeddings)
 
-    def _download_latent(self):
+    def _download_latent(self) -> List[Path]:
         """
-        Downloads the …
+        Downloads the feather tree for embeddings.
+        Returns the path to downloaded embeddings.
         """
-        …
+        # TODO: Is size of the embedding files (several hundreds of MBs) going to be a problem here?
+        self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
+        root_url = Path(
+            f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
+        )
 
-            params = {"projection_id": self.projection.id, "last_file": last, "page_size": limit}
-            r = requests.post(route, headers=self.projection.dataset.header, json=params)
-            if r.status_code == 204:
-                # Download complete!
-                break
-            fin = BytesIO(r.content)
-            tb = feather.read_table(fin, memory_map=True)
+        registered_sidecar_names = [sidecar[1] for sidecar in self.projection._registered_sidecars()]
+        assert "embeddings" in registered_sidecar_names, "Embeddings not found in sidecars."
 
-        …
+        downloaded_files_in_tile_order = []
+        logger.info("Downloading latent embeddings...")
+        all_quads = list(self.projection._tiles_in_order())
+        for quad in tqdm(all_quads):
+            path = quad.with_suffix(".embeddings.feather")
+            # WARNING: Potentially large data request here
+            quadtree_loc = Path(*path.parts[-3:])
+            download_feather(root_url / quadtree_loc, path, headers=self.dataset.header, overwrite=False)
+            downloaded_files_in_tile_order.append(path)
+        return downloaded_files_in_tile_order
 
     def vector_search(
         self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5
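The rewritten `latent` path walks the quadtree in tile order and pulls one `.embeddings.feather` file per tile through the shared `download_feather` helper, instead of globbing paginated files on disk. A hedged sketch of the public access path (the dataset slug is a placeholder):

```python
from nomic import AtlasDataset

dataset = AtlasDataset("my-org/my-dataset")  # hypothetical dataset identifier
projection = dataset.maps[0]

# First access downloads every tile's embeddings, then stacks them row-wise.
latent = projection.embeddings.latent
print(latent.shape)  # (num_datums, embedding_dim)
```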
@@ -694,6 +679,7 @@ class AtlasMapTags:
         """
         Downloads the feather tree for large sidecar columns.
         """
+        logger.info("Downloading tags")
         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
         root_url = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
 
@@ -706,20 +692,7 @@ class AtlasMapTags:
             quad_str = os.path.join(*[str(q) for q in quad])
             filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather"
             path = self.projection.tile_destination / Path(filename)
-            download_attempt = 0
-            download_success = False
-            while download_attempt < 3 and not download_success:
-                download_attempt += 1
-                if not path.exists() or overwrite:
-                    download_feather(root_url + filename, path, headers=self.dataset.header)
-                try:
-                    ipc.open_file(path).schema
-                    download_success = True
-                except pa.ArrowInvalid:
-                    path.unlink(missing_ok=True)
-
-            if not download_success:
-                raise Exception(f"Failed to download tag {tag_name}.")
+            download_feather(root_url + filename, path, headers=self.dataset.header, overwrite=True)
             ordered_tag_paths.append(path)
         return ordered_tag_paths
 
@@ -791,7 +764,6 @@ class AtlasMapData:
         self.projection = projection
         self.dataset = projection.dataset
         self.id_field = self.projection.dataset.id_field
-        self.fields = fields
         try:
             # Run fetch_tiles first to guarantee existence of quad feather files
             self._basic_data: pa.Table = self.projection._fetch_tiles()
@@ -801,29 +773,15 @@ class AtlasMapData:
         except pa.lib.ArrowInvalid as e:  # type: ignore
             raise ValueError("Failed to fetch tiles for this map")
 
-    def _read_prefetched_tiles_with_sidecars(self, additional_sidecars):
+    def _read_prefetched_tiles_with_sidecars(self, sidecars):
         tbs = []
-        root = feather.read_table(self.projection.tile_destination / Path("0/0/0.feather"))  # type: ignore
-        try:
-            small_sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()])
-        except KeyError:
-            small_sidecars = set([])
         for path in self.projection._tiles_in_order():
             tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"])  # type: ignore
             for col in tb.column_names:
                 if col[0] == "_":
                     tb = tb.drop([col])
-            for sidecar in small_sidecars:
-                carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar}.feather", …)
-                for col in carfile.column_names:
-                    tb = tb.append_column(col, carfile[col])
-            for big_sidecar in additional_sidecars:
-                fname = (
-                    base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8")
-                    if big_sidecar != "datum_id"
-                    else big_sidecar
-                )
-                carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True)  # type: ignore
+            for _, sidecar in sidecars:
+                carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar}.feather", memory_map=True)  # type: ignore
                 for col in carfile.column_names:
                     tb = tb.append_column(col, carfile[col])
             tbs.append(tb)
@@ -835,36 +793,31 @@ class AtlasMapData:
         """
         Downloads the feather tree for large sidecar columns.
         """
+        logger.info("Downloading dataset")
         self.projection.tile_destination.mkdir(parents=True, exist_ok=True)
         root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.projection.id}/quadtree/"
 
         all_quads = list(self.projection._tiles_in_order(coords_only=True))
-        sidecars = …
-        …
-            sidecars = [
-                field
-                for field in self.dataset.dataset_fields
-                if field not in self._basic_data.column_names and field != "_embeddings"
-            ]
+        sidecars = None
+        if fields is None:
+            fields = self.dataset.dataset_fields
         else:
-            for field in …
+            for field in fields:
                 assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
-            …
+
+        sidecars = [
+            (field, sidecar)
+            for field, sidecar in self.projection._registered_sidecars()
+            if field[0] != "_" and field in fields
+        ]
 
         for quad in tqdm(all_quads):
-            for encoded_colname in …
+            for field, encoded_colname in sidecars:
                 quad_str = os.path.join(*[str(q) for q in quad])
                 filename = quad_str + "." + encoded_colname + ".feather"
                 path = self.projection.tile_destination / Path(filename)
-
-                # WARNING: Potentially large data request here
-                download_feather(root + filename, path, headers=self.dataset.header)
-
+                # WARNING: Potentially large data request here
+                download_feather(root + filename, path, headers=self.dataset.header, overwrite=False)
         return sidecars
 
     @property
--- a/nomic/dataset.py
+++ b/nomic/dataset.py
@@ -36,7 +36,7 @@ from .data_inference import (
 )
 from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics
 from .settings import *
-from .utils import assert_valid_project_id, …
+from .utils import assert_valid_project_id, download_feather
 
 
 class AtlasUser:
@@ -611,6 +611,7 @@ class AtlasProjection:
         """
         if self._tile_data is not None:
             return self._tile_data
+        logger.info(f"Downloading files for projection {self.projection_id}")
         self._download_large_feather(overwrite=overwrite)
         tbs = []
         root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True)
@@ -634,8 +635,10 @@ class AtlasProjection:
 
     @overload
     def _tiles_in_order(self, *, coords_only: Literal[False] = ...) -> Iterator[Path]: ...
+
     @overload
     def _tiles_in_order(self, *, coords_only: Literal[True]) -> Iterator[Tuple[int, int, int]]: ...
+
     @overload
     def _tiles_in_order(self, *, coords_only: bool) -> Iterator[Any]: ...
 
@@ -693,27 +696,7 @@ class AtlasProjection:
             quad = rawquad + ".feather"
             all_quads.append(quad)
             path = self.tile_destination / quad
-
-            download_attempt = 0
-            download_success = False
-            schema = None
-            while download_attempt < 3 and not download_success:
-                download_attempt += 1
-                if not path.exists() or overwrite:
-                    data = requests.get(root + quad, headers=self.dataset.header)
-                    readable = io.BytesIO(data.content)
-                    readable.seek(0)
-                    tb = feather.read_table(readable, memory_map=True)
-                    path.parent.mkdir(parents=True, exist_ok=True)
-                    feather.write_feather(tb, path)
-                try:
-                    schema = ipc.open_file(path).schema
-                    download_success = True
-                except pa.ArrowInvalid:
-                    path.unlink(missing_ok=True)
-
-            if not download_success or schema is None:
-                raise Exception(f"Failed to download tiles. Aborting...")
+            schema = download_feather(root + quad, path, headers=self.dataset.header, overwrite=overwrite)
 
         if sidecars is None and b"sidecars" in schema.metadata:
             # Grab just the filenames
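The per-tile retry loop above collapses into one call to the shared `download_feather` helper (added in `nomic/utils.py` below), which returns the validated Arrow schema whose metadata is read next. A small sketch of that metadata lookup, assuming a root tile was already downloaded to a hypothetical local path:

```python
import json
from pyarrow import ipc

# Hypothetical tile written by _download_large_feather.
schema = ipc.open_file("tiles/0/0/0.feather").schema
if schema.metadata is not None and b"sidecars" in schema.metadata:
    # JSON mapping of field names to sidecar file suffixes.
    print(json.loads(schema.metadata[b"sidecars"]))
```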
@@ -1160,7 +1143,7 @@ class AtlasDataset(AtlasClass):
 
         build_template = {}
         if self.modality == "embedding":
-            if topic_model.community_description_target_field is None:
+            if topic_model.topic_label_field is None:
                 logger.warning(
                     "You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
                 )
@@ -1188,7 +1171,7 @@ class AtlasDataset(AtlasClass):
                 "topic_model_hyperparameters": json.dumps(
                     {
                         "build_topic_model": topic_model.build_topic_model,
-                        "community_description_target_field": topic_model.community_description_target_field,
+                        "community_description_target_field": topic_model.topic_label_field,  # TODO change key to topic_label_field post v0.0.85
                         "cluster_method": topic_model.cluster_method,
                         "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
                     }
@@ -1253,7 +1236,7 @@ class AtlasDataset(AtlasClass):
                 "topic_model_hyperparameters": json.dumps(
                     {
                         "build_topic_model": topic_model.build_topic_model,
-                        "community_description_target_field": indexed_field,
+                        "community_description_target_field": indexed_field,  # TODO change key to topic_label_field post v0.0.85
                         "cluster_method": topic_model.build_topic_model,
                         "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy,
                     }
--- a/nomic/embed.py
+++ b/nomic/embed.py
@@ -7,6 +7,7 @@ import os
 import time
 from io import BytesIO
 from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence, Tuple, Union, overload
+from urllib.parse import urlparse
 
 import PIL
 import PIL.Image
@@ -307,7 +308,11 @@ def free_embedding_model() -> None:
     _embed4all = _embed4all_kwargs = None
 
 
-def image_api_request(images: List[Tuple[str, bytes]], model: str = "nomic-embed…
+def image_api_request(
+    images: Optional[List[Tuple[str, bytes]]] = None,
+    urls: Optional[List[str]] = None,
+    model: str = "nomic-embed-vision-v1",
+):
     global atlas_class
 
     assert atlas_class is not None
@@ -318,7 +323,7 @@ def image_api_request(images: List[Tuple[str, bytes]], model: str = "nomic-embed…
         lambda: requests.post(
             atlas_url + "/v1/embedding/image",
             headers=atlas_header,
-            data={"model": model},
+            data={"model": model, "urls": urls},
             files=images,
         )
     )
@@ -340,7 +345,14 @@ def resize_pil(img):
     return img
 
 
-def image(images: Sequence[Union[str, PIL.Image.Image]], model: str = "nomic-emb…
+def _is_valid_url(url):
+    if not isinstance(url, str):
+        return False
+    parsed_url = urlparse(url)
+    return all([parsed_url.scheme, parsed_url.netloc])
+
+
+def image(images: Sequence[Union[str, PIL.Image.Image]], model: str = "nomic-embed-vision-v1.5"):
     """
     Generates embeddings for the given images.
 
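`_is_valid_url` accepts a string only when `urlparse` finds both a scheme and a network location, which is what lets bare filesystem paths fall through to the local-file branch in `image()`:

```python
from urllib.parse import urlparse

def _is_valid_url(url):
    # Same logic as the helper added above.
    if not isinstance(url, str):
        return False
    parsed_url = urlparse(url)
    return all([parsed_url.scheme, parsed_url.netloc])

assert _is_valid_url("https://example.com/cat.jpg")
assert not _is_valid_url("images/cat.jpg")  # bare path: no scheme or host
assert not _is_valid_url(42)                # non-strings are rejected outright
```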
@@ -362,14 +374,24 @@ def image(images: Sequence[Union[str, PIL.Image.Image]], model: str = "nomic-emb…
     # if there are fewer images per worker than the max chunksize just split them evenly
     chunksize = min(smallchunk, chunksize)
 
+    if isinstance(images, str):
+        raise TypeError("'images' parameter must be list of strings or PIL images, not str")
+
+    urls = []
     image_batch = []
     for image in images:
-        if isinstance(image, str) and os.path.exists(image):
-            img = resize_pil(PIL.Image.open(image))
-            buffered = BytesIO()
-            img.save(buffered, format="JPEG")
-            …
+        if isinstance(image, str):
+            if os.path.exists(image):
+                img = resize_pil(PIL.Image.open(image)).convert("RGB")
+                buffered = BytesIO()
+                img.save(buffered, format="JPEG")
+                image_batch.append(("images", buffered.getvalue()))
+            elif _is_valid_url(image):
+                # Send URL as data
+                urls.append(image)
+            else:
+                raise ValueError(f"Invalid image path or url: {image}")
 
         elif isinstance(image, PIL.Image.Image):
             img = resize_pil(image)
@@ -379,13 +401,22 @@ def image(images: Sequence[Union[str, PIL.Image.Image]], model: str = "nomic-emb…
         else:
             raise ValueError(f"Not a valid file: {image}")
 
+    if len(urls) > 0 and len(image_batch) > 0:
+        raise ValueError("Provide either urls or image files/objects, not both.")
+
+    num_images = len(urls) + len(image_batch)
     combined = {"embeddings": [], "usage": {}, "model": model}
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = []
-        for chunkstart in range(0, …
-            chunkend = min(…
-            …
+        for chunkstart in range(0, num_images, chunksize):
+            chunkend = min(num_images, chunkstart + chunksize)
+            image_chunk = None
+            url_chunk = None
+            if len(image_batch) > 0:
+                image_chunk = image_batch[chunkstart:chunkend]
+            else:
+                url_chunk = urls[chunkstart:chunkend]
+            futures.append(executor.submit(image_api_request, image_chunk, url_chunk, model))
 
         for future in futures:
             response = future.result()
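With these changes `embed.image` accepts HTTP(S) URLs as well as local paths or PIL images, but a single call must be all URLs or all files, never a mix. A hedged usage sketch (placeholder URLs; assumes a configured Nomic API key):

```python
from nomic import embed

# URL inputs are forwarded in the request body instead of uploaded as bytes.
output = embed.image(
    images=["https://example.com/cat.jpg", "https://example.com/dog.jpg"],
    model="nomic-embed-vision-v1.5",
)
print(output["model"], len(output["embeddings"]))
```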
--- a/nomic/utils.py
+++ b/nomic/utils.py
@@ -4,11 +4,12 @@ import random
 import sys
 from io import BytesIO
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
 from uuid import UUID
 
 import pyarrow as pa
 import requests
+from pyarrow import ipc
 
 nouns = [
     "newton",
@@ -241,10 +242,31 @@ def get_object_size_in_bytes(obj):
 
 # Helpful function for downloading feather files
 # Best for small feather files
-def download_feather(…
-    …
+def download_feather(
+    url: Union[str, Path], path: Path, headers: Optional[dict] = None, retries=3, overwrite=False
+) -> pa.Schema:
+    """
+    Download a feather file from a URL to a local path.
+    Returns the schema of feather file if successful.
+    """
+    assert retries > 0, "Retries must be greater than 0"
+    download_attempt = 0
+    download_success = False
+    schema = None
+    while download_attempt < retries and not download_success:
+        download_attempt += 1
+        if not path.exists() or overwrite:
+            data = requests.get(str(url), headers=headers)
+            readable = BytesIO(data.content)
+            readable.seek(0)
+            tb = pa.feather.read_table(readable, memory_map=False)  # type: ignore
+            path.parent.mkdir(parents=True, exist_ok=True)
+            pa.feather.write_feather(tb, path)  # type: ignore
+        try:
+            schema = ipc.open_file(path).schema
+            download_success = True
+        except pa.ArrowInvalid:
+            path.unlink(missing_ok=True)
+    if not download_success or schema is None:
+        raise ValueError(f"Failed to download feather file from {url} after {retries} attempts.")
+    return schema
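A sketch of the consolidated helper in use; the tile URL and token are placeholders. The helper retries up to `retries` times, validates each download by reopening it with `pyarrow.ipc`, deletes corrupt partial files, and returns the schema:

```python
from pathlib import Path
from nomic.utils import download_feather

schema = download_feather(
    "https://api-atlas.nomic.ai/v1/project/<project_id>/index/projection/<projection_id>/quadtree/0/0/0.feather",
    Path("tiles/0/0/0.feather"),
    headers={"Authorization": "Bearer <token>"},  # hypothetical auth header
    retries=3,
    overwrite=False,
)
print(schema.names)
```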
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@ description = "The official Nomic python client."
 
 setup(
     name="nomic",
-    version="3.0.30",
+    version="3.0.32",
     url="https://github.com/nomic-ai/nomic",
     description=description,
     long_description=description,
@@ -43,7 +43,7 @@ setup(
     ],
     "dev": [
         "nomic[all]",
-        "black",
+        "black==24.3.0",
         "coverage",
         "pylint",
         "pytest",