vexor 0.19.0a1__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
vexor/cli.py CHANGED
@@ -31,6 +31,7 @@ from .config import (
     DEFAULT_MODEL,
     DEFAULT_PROVIDER,
     DEFAULT_RERANK,
+    SUPPORTED_EXTRACT_BACKENDS,
     SUPPORTED_PROVIDERS,
     SUPPORTED_RERANKERS,
     flashrank_cache_dir,
@@ -389,6 +390,11 @@ def search(
         "--format",
         help=Messages.HELP_SEARCH_FORMAT,
     ),
+    no_cache: bool = typer.Option(
+        False,
+        "--no-cache",
+        help=Messages.HELP_NO_CACHE,
+    ),
 ) -> None:
     """Run the semantic search."""
     config = load_config()
@@ -396,6 +402,8 @@ def search(
     model_name = resolve_default_model(provider, config.model)
     batch_size = config.batch_size if config.batch_size is not None else DEFAULT_BATCH_SIZE
     embed_concurrency = config.embed_concurrency
+    extract_concurrency = config.extract_concurrency
+    extract_backend = config.extract_backend
     base_url = config.base_url
     api_key = config.api_key
     auto_index = bool(config.auto_index)
@@ -433,6 +441,8 @@ def search(
         model_name=model_name,
         batch_size=batch_size,
         embed_concurrency=embed_concurrency,
+        extract_concurrency=extract_concurrency,
+        extract_backend=extract_backend,
         provider=provider,
         base_url=base_url,
         api_key=api_key,
@@ -440,20 +450,35 @@ def search(
         exclude_patterns=normalized_excludes,
         extensions=normalized_exts,
         auto_index=auto_index,
+        no_cache=no_cache,
         rerank=rerank,
         flashrank_model=flashrank_model,
         remote_rerank=remote_rerank,
     )
     if output_format == SearchOutputFormat.rich:
-        should_index_first = _should_index_before_search(request) if auto_index else False
-        if should_index_first:
+        if no_cache:
             console.print(
-                _styled(Messages.INFO_INDEX_RUNNING.format(path=directory), Styles.INFO)
+                _styled(
+                    Messages.INFO_SEARCH_RUNNING_NO_CACHE.format(path=directory),
+                    Styles.INFO,
+                )
             )
         else:
-            console.print(
-                _styled(Messages.INFO_SEARCH_RUNNING.format(path=directory), Styles.INFO)
+            should_index_first = (
+                _should_index_before_search(request) if auto_index else False
             )
+            if should_index_first:
+                console.print(
+                    _styled(
+                        Messages.INFO_INDEX_RUNNING.format(path=directory), Styles.INFO
+                    )
+                )
+            else:
+                console.print(
+                    _styled(
+                        Messages.INFO_SEARCH_RUNNING.format(path=directory), Styles.INFO
+                    )
+                )
     try:
         response = perform_search(request)
     except FileNotFoundError:
@@ -557,6 +582,8 @@ def index(
     model_name = resolve_default_model(provider, config.model)
     batch_size = config.batch_size if config.batch_size is not None else DEFAULT_BATCH_SIZE
     embed_concurrency = config.embed_concurrency
+    extract_concurrency = config.extract_concurrency
+    extract_backend = config.extract_backend
     base_url = config.base_url
     api_key = config.api_key
 
@@ -653,6 +680,8 @@ def index(
         model_name=model_name,
         batch_size=batch_size,
         embed_concurrency=embed_concurrency,
+        extract_concurrency=extract_concurrency,
+        extract_backend=extract_backend,
         provider=provider,
         base_url=base_url,
         api_key=api_key,
@@ -714,6 +743,16 @@ def config(
         "--set-embed-concurrency",
         help=Messages.HELP_SET_EMBED_CONCURRENCY,
     ),
+    set_extract_concurrency_option: int | None = typer.Option(
+        None,
+        "--set-extract-concurrency",
+        help=Messages.HELP_SET_EXTRACT_CONCURRENCY,
+    ),
+    set_extract_backend_option: str | None = typer.Option(
+        None,
+        "--set-extract-backend",
+        help=Messages.HELP_SET_EXTRACT_BACKEND,
+    ),
     set_provider_option: str | None = typer.Option(
         None,
         "--set-provider",
@@ -790,6 +829,8 @@ def config(
         raise typer.BadParameter(Messages.ERROR_BATCH_NEGATIVE)
     if set_embed_concurrency_option is not None and set_embed_concurrency_option < 1:
         raise typer.BadParameter(Messages.ERROR_CONCURRENCY_INVALID)
+    if set_extract_concurrency_option is not None and set_extract_concurrency_option < 1:
+        raise typer.BadParameter(Messages.ERROR_EXTRACT_CONCURRENCY_INVALID)
     if set_base_url_option and clear_base_url:
         raise typer.BadParameter(Messages.ERROR_BASE_URL_CONFLICT)
     flashrank_model_reset = False
@@ -815,6 +856,16 @@ def config(
         if not normalized_remote_key:
             raise typer.BadParameter(Messages.ERROR_REMOTE_RERANK_API_KEY_EMPTY)
         set_remote_rerank_api_key_option = normalized_remote_key
+    if set_extract_backend_option is not None:
+        normalized_backend = set_extract_backend_option.strip().lower()
+        if normalized_backend not in SUPPORTED_EXTRACT_BACKENDS:
+            allowed = ", ".join(SUPPORTED_EXTRACT_BACKENDS)
+            raise typer.BadParameter(
+                Messages.ERROR_EXTRACT_BACKEND_INVALID.format(
+                    value=set_extract_backend_option, allowed=allowed
+                )
+            )
+        set_extract_backend_option = normalized_backend
     if clear_remote_rerank and any(
         (
             set_remote_rerank_url_option is not None,
@@ -830,6 +881,8 @@ def config(
             set_model_option is not None,
             set_batch_option is not None,
             set_embed_concurrency_option is not None,
+            set_extract_concurrency_option is not None,
+            set_extract_backend_option is not None,
             set_provider_option is not None,
             set_base_url_option is not None,
             clear_base_url,
@@ -942,6 +995,8 @@ def config(
         model=set_model_option,
         batch_size=set_batch_option,
         embed_concurrency=set_embed_concurrency_option,
+        extract_concurrency=set_extract_concurrency_option,
+        extract_backend=set_extract_backend_option,
         provider=set_provider_option,
         base_url=set_base_url_option,
         clear_base_url=clear_base_url,
@@ -973,6 +1028,22 @@ def config(
                 Styles.SUCCESS,
             )
         )
+    if updates.extract_concurrency_set and set_extract_concurrency_option is not None:
+        console.print(
+            _styled(
+                Messages.INFO_EXTRACT_CONCURRENCY_SET.format(
+                    value=set_extract_concurrency_option
+                ),
+                Styles.SUCCESS,
+            )
+        )
+    if updates.extract_backend_set and set_extract_backend_option is not None:
+        console.print(
+            _styled(
+                Messages.INFO_EXTRACT_BACKEND_SET.format(value=set_extract_backend_option),
+                Styles.SUCCESS,
+            )
+        )
     if updates.provider_set and set_provider_option is not None:
         console.print(
             _styled(Messages.INFO_PROVIDER_SET.format(value=set_provider_option), Styles.SUCCESS)
@@ -1119,6 +1190,8 @@ def config(
         model=resolve_default_model(provider, cfg.model),
         batch=cfg.batch_size if cfg.batch_size is not None else DEFAULT_BATCH_SIZE,
         concurrency=cfg.embed_concurrency,
+        extract_concurrency=cfg.extract_concurrency,
+        extract_backend=cfg.extract_backend,
         auto_index="yes" if cfg.auto_index else "no",
         rerank=rerank,
         flashrank_line=flashrank_line,
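
Net effect of the cli.py changes: vexor search gains a --no-cache flag that skips the auto-index path and prints a dedicated status line, and vexor config gains --set-extract-concurrency / --set-extract-backend options whose values are validated against SUPPORTED_EXTRACT_BACKENDS. A minimal sketch of the backend validation shown in the hunks above, with a hypothetical validate_extract_backend helper and a literal error string standing in for Messages.ERROR_EXTRACT_BACKEND_INVALID:

import typer

SUPPORTED_EXTRACT_BACKENDS: tuple[str, ...] = ("auto", "thread", "process")

def validate_extract_backend(value: str) -> str:
    # Normalize the same way the config command does, then reject unknown names.
    normalized = value.strip().lower()
    if normalized not in SUPPORTED_EXTRACT_BACKENDS:
        allowed = ", ".join(SUPPORTED_EXTRACT_BACKENDS)
        raise typer.BadParameter(
            f"unsupported extract backend {value!r} (allowed: {allowed})"
        )
    return normalized

# validate_extract_backend("Process") -> "process"
# validate_extract_backend("fork")    -> raises typer.BadParameter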
vexor/config.py CHANGED
@@ -5,23 +5,30 @@ from __future__ import annotations
 import json
 import os
 from dataclasses import dataclass
+from collections.abc import Mapping
 from pathlib import Path
 from typing import Any, Dict
 from urllib.parse import urlparse, urlunparse
 
-CONFIG_DIR = Path(os.path.expanduser("~")) / ".vexor"
+from .text import Messages
+
+DEFAULT_CONFIG_DIR = Path(os.path.expanduser("~")) / ".vexor"
+CONFIG_DIR = DEFAULT_CONFIG_DIR
 CONFIG_FILE = CONFIG_DIR / "config.json"
 DEFAULT_MODEL = "text-embedding-3-small"
 DEFAULT_GEMINI_MODEL = "gemini-embedding-001"
 DEFAULT_LOCAL_MODEL = "intfloat/multilingual-e5-small"
 DEFAULT_BATCH_SIZE = 64
-DEFAULT_EMBED_CONCURRENCY = 2
+DEFAULT_EMBED_CONCURRENCY = 4
+DEFAULT_EXTRACT_CONCURRENCY = max(1, min(4, os.cpu_count() or 1))
+DEFAULT_EXTRACT_BACKEND = "auto"
 DEFAULT_PROVIDER = "openai"
 DEFAULT_RERANK = "off"
 DEFAULT_FLASHRANK_MODEL = "ms-marco-TinyBERT-L-2-v2"
 DEFAULT_FLASHRANK_MAX_LENGTH = 256
 SUPPORTED_PROVIDERS: tuple[str, ...] = (DEFAULT_PROVIDER, "gemini", "custom", "local")
 SUPPORTED_RERANKERS: tuple[str, ...] = ("off", "bm25", "flashrank", "remote")
+SUPPORTED_EXTRACT_BACKENDS: tuple[str, ...] = ("auto", "thread", "process")
 ENV_API_KEY = "VEXOR_API_KEY"
 REMOTE_RERANK_ENV = "VEXOR_REMOTE_RERANK_API_KEY"
 LEGACY_GEMINI_ENV = "GOOGLE_GENAI_API_KEY"
@@ -41,6 +48,8 @@ class Config:
     model: str = DEFAULT_MODEL
     batch_size: int = DEFAULT_BATCH_SIZE
     embed_concurrency: int = DEFAULT_EMBED_CONCURRENCY
+    extract_concurrency: int = DEFAULT_EXTRACT_CONCURRENCY
+    extract_backend: str = DEFAULT_EXTRACT_BACKEND
     provider: str = DEFAULT_PROVIDER
     base_url: str | None = None
     auto_index: bool = True
@@ -77,6 +86,10 @@ def load_config() -> Config:
         model=raw.get("model") or DEFAULT_MODEL,
         batch_size=int(raw.get("batch_size", DEFAULT_BATCH_SIZE)),
         embed_concurrency=int(raw.get("embed_concurrency", DEFAULT_EMBED_CONCURRENCY)),
+        extract_concurrency=int(
+            raw.get("extract_concurrency", DEFAULT_EXTRACT_CONCURRENCY)
+        ),
+        extract_backend=_coerce_extract_backend(raw.get("extract_backend")),
         provider=raw.get("provider") or DEFAULT_PROVIDER,
         base_url=raw.get("base_url") or None,
         auto_index=bool(raw.get("auto_index", True)),
@@ -96,6 +109,8 @@ def save_config(config: Config) -> None:
     data["model"] = config.model
     data["batch_size"] = config.batch_size
     data["embed_concurrency"] = config.embed_concurrency
+    data["extract_concurrency"] = config.extract_concurrency
+    data["extract_backend"] = config.extract_backend
     if config.provider:
         data["provider"] = config.provider
     if config.base_url:
@@ -129,6 +144,38 @@ def flashrank_cache_dir(*, create: bool = True) -> Path:
     return cache_dir
 
 
+def set_config_dir(path: Path | str | None) -> None:
+    global CONFIG_DIR, CONFIG_FILE
+    if path is None:
+        CONFIG_DIR = DEFAULT_CONFIG_DIR
+    else:
+        dir_path = Path(path).expanduser().resolve()
+        if dir_path.exists() and not dir_path.is_dir():
+            raise NotADirectoryError(f"Path is not a directory: {dir_path}")
+        CONFIG_DIR = dir_path
+    CONFIG_FILE = CONFIG_DIR / "config.json"
+
+
+def config_from_json(
+    payload: str | Mapping[str, object], *, base: Config | None = None
+) -> Config:
+    """Return a Config from a JSON string or mapping without saving it."""
+    data = _coerce_config_payload(payload)
+    config = Config() if base is None else _clone_config(base)
+    _apply_config_payload(config, data)
+    return config
+
+
+def update_config_from_json(
+    payload: str | Mapping[str, object], *, replace: bool = False
+) -> Config:
+    """Update config from a JSON string or mapping and persist it."""
+    base = None if replace else load_config()
+    config = config_from_json(payload, base=base)
+    save_config(config)
+    return config
+
+
 def set_api_key(value: str | None) -> None:
     config = load_config()
     config.api_key = value
@@ -153,6 +200,18 @@ def set_embed_concurrency(value: int) -> None:
     save_config(config)
 
 
+def set_extract_concurrency(value: int) -> None:
+    config = load_config()
+    config.extract_concurrency = value
+    save_config(config)
+
+
+def set_extract_backend(value: str) -> None:
+    config = load_config()
+    config.extract_backend = _normalize_extract_backend(value)
+    save_config(config)
+
+
 def set_provider(value: str) -> None:
     config = load_config()
     config.provider = value
@@ -281,3 +340,182 @@ def resolve_remote_rerank_api_key(configured: str | None) -> str | None:
     if env_key:
         return env_key
     return None
+
+
+def _coerce_config_payload(payload: str | Mapping[str, object]) -> Mapping[str, object]:
+    if isinstance(payload, str):
+        try:
+            data = json.loads(payload)
+        except json.JSONDecodeError as exc:
+            raise ValueError(Messages.ERROR_CONFIG_JSON_INVALID) from exc
+    elif isinstance(payload, Mapping):
+        data = dict(payload)
+    else:
+        raise ValueError(Messages.ERROR_CONFIG_JSON_INVALID)
+    if not isinstance(data, Mapping):
+        raise ValueError(Messages.ERROR_CONFIG_JSON_INVALID)
+    return data
+
+
+def _clone_config(config: Config) -> Config:
+    remote = config.remote_rerank
+    return Config(
+        api_key=config.api_key,
+        model=config.model,
+        batch_size=config.batch_size,
+        embed_concurrency=config.embed_concurrency,
+        extract_concurrency=config.extract_concurrency,
+        extract_backend=config.extract_backend,
+        provider=config.provider,
+        base_url=config.base_url,
+        auto_index=config.auto_index,
+        local_cuda=config.local_cuda,
+        rerank=config.rerank,
+        flashrank_model=config.flashrank_model,
+        remote_rerank=(
+            None
+            if remote is None
+            else RemoteRerankConfig(
+                base_url=remote.base_url,
+                api_key=remote.api_key,
+                model=remote.model,
+            )
+        ),
+    )
+
+
+def _apply_config_payload(config: Config, payload: Mapping[str, object]) -> None:
+    if "api_key" in payload:
+        config.api_key = _coerce_optional_str(payload["api_key"], "api_key")
+    if "model" in payload:
+        config.model = _coerce_required_str(payload["model"], "model", DEFAULT_MODEL)
+    if "batch_size" in payload:
+        config.batch_size = _coerce_int(
+            payload["batch_size"], "batch_size", DEFAULT_BATCH_SIZE
+        )
+    if "embed_concurrency" in payload:
+        config.embed_concurrency = _coerce_int(
+            payload["embed_concurrency"],
+            "embed_concurrency",
+            DEFAULT_EMBED_CONCURRENCY,
+        )
+    if "extract_concurrency" in payload:
+        config.extract_concurrency = _coerce_int(
+            payload["extract_concurrency"],
+            "extract_concurrency",
+            DEFAULT_EXTRACT_CONCURRENCY,
+        )
+    if "extract_backend" in payload:
+        config.extract_backend = _normalize_extract_backend(payload["extract_backend"])
+    if "provider" in payload:
+        config.provider = _coerce_required_str(
+            payload["provider"], "provider", DEFAULT_PROVIDER
+        )
+    if "base_url" in payload:
+        config.base_url = _coerce_optional_str(payload["base_url"], "base_url")
+    if "auto_index" in payload:
+        config.auto_index = _coerce_bool(payload["auto_index"], "auto_index")
+    if "local_cuda" in payload:
+        config.local_cuda = _coerce_bool(payload["local_cuda"], "local_cuda")
+    if "rerank" in payload:
+        config.rerank = _normalize_rerank(payload["rerank"])
+    if "flashrank_model" in payload:
+        config.flashrank_model = _coerce_optional_str(
+            payload["flashrank_model"], "flashrank_model"
+        )
+    if "remote_rerank" in payload:
+        config.remote_rerank = _coerce_remote_rerank(payload["remote_rerank"])
+
+
+def _coerce_optional_str(value: object, field: str) -> str | None:
+    if value is None:
+        return None
+    if isinstance(value, str):
+        cleaned = value.strip()
+        return cleaned or None
+    raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field=field))
+
+
+def _coerce_required_str(value: object, field: str, default: str) -> str:
+    if value is None:
+        return default
+    if isinstance(value, str):
+        cleaned = value.strip()
+        return cleaned or default
+    raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field=field))
+
+
+def _coerce_int(value: object, field: str, default: int) -> int:
+    if value is None:
+        return default
+    if isinstance(value, bool):
+        raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field=field))
+    if isinstance(value, int):
+        return value
+    if isinstance(value, float):
+        if value.is_integer():
+            return int(value)
+        raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field=field))
+    if isinstance(value, str):
+        cleaned = value.strip()
+        if not cleaned:
+            return default
+        try:
+            return int(cleaned)
+        except ValueError as exc:
+            raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field=field)) from exc
+    raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field=field))
+
+
+def _coerce_bool(value: object, field: str) -> bool:
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, int) and value in (0, 1):
+        return bool(value)
+    if isinstance(value, str):
+        cleaned = value.strip().lower()
+        if cleaned in {"true", "1", "yes", "on"}:
+            return True
+        if cleaned in {"false", "0", "no", "off"}:
+            return False
+    raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field=field))
+
+
+def _normalize_extract_backend(value: object) -> str:
+    if value is None:
+        return DEFAULT_EXTRACT_BACKEND
+    if isinstance(value, str):
+        normalized = value.strip().lower() or DEFAULT_EXTRACT_BACKEND
+        if normalized in SUPPORTED_EXTRACT_BACKENDS:
+            return normalized
+    raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field="extract_backend"))
+
+
+def _coerce_extract_backend(value: object) -> str:
+    if value is None:
+        return DEFAULT_EXTRACT_BACKEND
+    if isinstance(value, str):
+        normalized = value.strip().lower()
+        if normalized in SUPPORTED_EXTRACT_BACKENDS:
+            return normalized
+    return DEFAULT_EXTRACT_BACKEND
+
+
+def _normalize_rerank(value: object) -> str:
+    if value is None:
+        normalized = DEFAULT_RERANK
+    elif isinstance(value, str):
+        normalized = value.strip().lower() or DEFAULT_RERANK
+    else:
+        raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field="rerank"))
+    if normalized not in SUPPORTED_RERANKERS:
+        normalized = DEFAULT_RERANK
+    return normalized
+
+
+def _coerce_remote_rerank(value: object) -> RemoteRerankConfig | None:
+    if value is None:
+        return None
+    if isinstance(value, Mapping):
+        return _parse_remote_rerank(dict(value))
+    raise ValueError(Messages.ERROR_CONFIG_VALUE_INVALID.format(field="remote_rerank"))
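
The new JSON helpers make the config file scriptable: config_from_json builds a Config from a JSON string or mapping without touching disk, while update_config_from_json merges the payload into the saved config (or starts from a fresh Config with replace=True) and persists it. A hypothetical round-trip, assuming the package is importable as vexor.config:

from vexor.config import config_from_json

# Build a Config in memory; fields absent from the payload keep their defaults.
cfg = config_from_json('{"extract_backend": "Thread", "extract_concurrency": "3"}')
assert cfg.extract_backend == "thread"  # _normalize_extract_backend lowercases
assert cfg.extract_concurrency == 3     # _coerce_int accepts integer-like strings

# update_config_from_json({"extract_concurrency": 3}) would apply the same
# coercion to the loaded config, then save the merged result.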
vexor/providers/gemini.py CHANGED
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from concurrent.futures import ThreadPoolExecutor, as_completed
+import time
 from typing import Iterator, Sequence
 
 import numpy as np
@@ -38,14 +39,19 @@ class GeminiEmbeddingBackend:
         if base_url:
             client_kwargs["http_options"] = genai_types.HttpOptions(base_url=base_url)
         self._client = genai.Client(**client_kwargs)
+        self._executor: ThreadPoolExecutor | None = None
 
     def embed(self, texts: Sequence[str]) -> np.ndarray:
         if not texts:
             return np.empty((0, 0), dtype=np.float32)
-        batches = list(_chunk(texts, self.chunk_size))
-        if self.concurrency > 1 and len(batches) > 1:
-            vectors_by_batch: list[list[np.ndarray] | None] = [None] * len(batches)
-            with ThreadPoolExecutor(max_workers=min(self.concurrency, len(batches))) as executor:
+        if self.concurrency > 1:
+            batches = list(_chunk(texts, self.chunk_size))
+            if len(batches) > 1:
+                vectors_by_batch: list[list[np.ndarray] | None] = [None] * len(batches)
+                executor = self._executor
+                if executor is None:
+                    executor = ThreadPoolExecutor(max_workers=self.concurrency)
+                    self._executor = executor
                 future_map = {
                     executor.submit(self._embed_batch, batch): idx
                     for idx, batch in enumerate(batches)
@@ -53,23 +59,34 @@ class GeminiEmbeddingBackend:
                 for future in as_completed(future_map):
                     idx = future_map[future]
                     vectors_by_batch[idx] = future.result()
-            vectors = [vec for batch in vectors_by_batch if batch for vec in batch]
+                vectors = [vec for batch in vectors_by_batch if batch for vec in batch]
+            else:
+                vectors = []
+                for batch in batches:
+                    vectors.extend(self._embed_batch(batch))
         else:
             vectors = []
-            for batch in batches:
+            for batch in _chunk(texts, self.chunk_size):
                 vectors.extend(self._embed_batch(batch))
         if not vectors:
             raise RuntimeError(Messages.ERROR_NO_EMBEDDINGS)
         return np.vstack(vectors)
 
     def _embed_batch(self, batch: Sequence[str]) -> list[np.ndarray]:
-        try:
-            response = self._client.models.embed_content(
-                model=self.model_name,
-                contents=list(batch),
-            )
-        except genai_errors.ClientError as exc:
-            raise RuntimeError(_format_genai_error(exc)) from exc
+        attempt = 0
+        while True:
+            try:
+                response = self._client.models.embed_content(
+                    model=self.model_name,
+                    contents=list(batch),
+                )
+                break
+            except genai_errors.ClientError as exc:
+                if _should_retry_genai_error(exc) and attempt < _MAX_RETRIES:
+                    _sleep(_backoff_delay(attempt))
+                    attempt += 1
+                    continue
+                raise RuntimeError(_format_genai_error(exc)) from exc
         embeddings = getattr(response, "embeddings", None)
         if not embeddings:
             raise RuntimeError(Messages.ERROR_NO_EMBEDDINGS)
@@ -90,6 +107,55 @@ def _chunk(items: Sequence[str], size: int | None) -> Iterator[Sequence[str]]:
         yield items[idx : idx + size]
 
 
+_RETRYABLE_STATUS_CODES = {408, 429, 500, 502, 503, 504}
+_MAX_RETRIES = 2
+_RETRY_BASE_DELAY = 0.5
+_RETRY_MAX_DELAY = 4.0
+
+
+def _sleep(seconds: float) -> None:
+    time.sleep(seconds)
+
+
+def _backoff_delay(attempt: int) -> float:
+    return min(_RETRY_MAX_DELAY, _RETRY_BASE_DELAY * (2**attempt))
+
+
+def _extract_status_code(exc: Exception) -> int | None:
+    for attr in ("status_code", "status", "http_status"):
+        value = getattr(exc, attr, None)
+        if isinstance(value, int):
+            return value
+    response = getattr(exc, "response", None)
+    if response is not None:
+        value = getattr(response, "status_code", None)
+        if isinstance(value, int):
+            return value
+    return None
+
+
+def _should_retry_genai_error(exc: Exception) -> bool:
+    status = _extract_status_code(exc)
+    if status in _RETRYABLE_STATUS_CODES:
+        return True
+    name = exc.__class__.__name__.lower()
+    if "ratelimit" in name or "timeout" in name or "temporarily" in name:
+        return True
+    message = str(exc).lower()
+    return any(
+        token in message
+        for token in (
+            "rate limit",
+            "timeout",
+            "temporar",
+            "overload",
+            "try again",
+            "too many requests",
+            "service unavailable",
+        )
+    )
+
+
 def _format_genai_error(exc: genai_errors.ClientError) -> str:
     message = getattr(exc, "message", None) or str(exc)
     if "API key" in message: