deepeval 3.4.7__py3-none-any.whl → 3.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. deepeval/__init__.py +8 -7
  2. deepeval/_version.py +1 -1
  3. deepeval/cli/dotenv_handler.py +71 -0
  4. deepeval/cli/main.py +1021 -280
  5. deepeval/cli/utils.py +116 -2
  6. deepeval/confident/api.py +29 -14
  7. deepeval/config/__init__.py +0 -0
  8. deepeval/config/settings.py +565 -0
  9. deepeval/config/settings_manager.py +133 -0
  10. deepeval/config/utils.py +86 -0
  11. deepeval/dataset/__init__.py +1 -0
  12. deepeval/dataset/dataset.py +70 -10
  13. deepeval/dataset/test_run_tracer.py +82 -0
  14. deepeval/dataset/utils.py +23 -0
  15. deepeval/key_handler.py +64 -2
  16. deepeval/metrics/__init__.py +4 -1
  17. deepeval/metrics/answer_relevancy/template.py +7 -2
  18. deepeval/metrics/conversational_dag/__init__.py +7 -0
  19. deepeval/metrics/conversational_dag/conversational_dag.py +139 -0
  20. deepeval/metrics/conversational_dag/nodes.py +931 -0
  21. deepeval/metrics/conversational_dag/templates.py +117 -0
  22. deepeval/metrics/dag/dag.py +13 -4
  23. deepeval/metrics/dag/graph.py +47 -15
  24. deepeval/metrics/dag/utils.py +103 -38
  25. deepeval/metrics/faithfulness/template.py +11 -8
  26. deepeval/metrics/multimodal_metrics/multimodal_answer_relevancy/template.py +6 -4
  27. deepeval/metrics/multimodal_metrics/multimodal_faithfulness/template.py +6 -4
  28. deepeval/metrics/tool_correctness/tool_correctness.py +7 -3
  29. deepeval/models/llms/amazon_bedrock_model.py +24 -3
  30. deepeval/models/llms/openai_model.py +37 -41
  31. deepeval/models/retry_policy.py +280 -0
  32. deepeval/openai_agents/agent.py +4 -2
  33. deepeval/synthesizer/chunking/doc_chunker.py +87 -51
  34. deepeval/test_run/api.py +1 -0
  35. deepeval/tracing/otel/exporter.py +20 -8
  36. deepeval/tracing/otel/utils.py +57 -0
  37. deepeval/tracing/tracing.py +37 -16
  38. deepeval/tracing/utils.py +98 -1
  39. deepeval/utils.py +111 -70
  40. {deepeval-3.4.7.dist-info → deepeval-3.4.9.dist-info}/METADATA +3 -1
  41. {deepeval-3.4.7.dist-info → deepeval-3.4.9.dist-info}/RECORD +44 -34
  42. deepeval/env.py +0 -35
  43. {deepeval-3.4.7.dist-info → deepeval-3.4.9.dist-info}/LICENSE.md +0 -0
  44. {deepeval-3.4.7.dist-info → deepeval-3.4.9.dist-info}/WHEEL +0 -0
  45. {deepeval-3.4.7.dist-info → deepeval-3.4.9.dist-info}/entry_points.txt +0 -0
@@ -1,26 +1,33 @@
1
+ import logging
2
+
1
3
  from openai.types.chat.chat_completion import ChatCompletion
2
4
  from deepeval.key_handler import ModelKeyValues, KEY_FILE_HANDLER
3
5
  from typing import Optional, Tuple, Union, Dict
4
- from openai import OpenAI, AsyncOpenAI
5
6
  from pydantic import BaseModel
6
- import logging
7
- import openai
8
7
 
9
- from tenacity import (
10
- retry,
11
- retry_if_exception_type,
12
- wait_exponential_jitter,
13
- RetryCallState,
8
+ from openai import (
9
+ OpenAI,
10
+ AsyncOpenAI,
14
11
  )
15
12
 
13
+ from tenacity import retry, RetryCallState, before_sleep_log
14
+
16
15
  from deepeval.models import DeepEvalBaseLLM
17
16
  from deepeval.models.llms.utils import trim_and_load_json
18
17
  from deepeval.models.utils import parse_model_name
18
+ from deepeval.models.retry_policy import (
19
+ OPENAI_ERROR_POLICY,
20
+ default_wait,
21
+ default_stop,
22
+ retry_predicate,
23
+ )
24
+
25
+ logger = logging.getLogger("deepeval.openai_model")
19
26
 
20
27
 
21
28
  def log_retry_error(retry_state: RetryCallState):
22
29
  exception = retry_state.outcome.exception()
23
- logging.error(
30
+ logger.error(
24
31
  f"OpenAI Error: {exception} Retrying: {retry_state.attempt_number} time(s)..."
25
32
  )
26
33
 
@@ -212,14 +219,22 @@ models_requiring_temperature_1 = [
212
219
  "gpt-5-chat-latest",
213
220
  ]
214
221
 
215
- retryable_exceptions = (
216
- openai.RateLimitError,
217
- openai.APIConnectionError,
218
- openai.APITimeoutError,
219
- openai.LengthFinishReasonError,
222
+ _base_retry_rules_kw = dict(
223
+ wait=default_wait(),
224
+ stop=default_stop(),
225
+ retry=retry_predicate(OPENAI_ERROR_POLICY),
226
+ before_sleep=before_sleep_log(
227
+ logger, logging.INFO
228
+ ), # <- logs only on retries
229
+ after=log_retry_error,
220
230
  )
221
231
 
222
232
 
233
+ def _openai_client_kwargs():
234
+ # Avoid double-retry at SDK layer by disabling the SDK's own retries so tenacity is the single source of truth for retry logic.
235
+ return {"max_retries": 0}
236
+
237
+
223
238
  class GPTModel(DeepEvalBaseLLM):
224
239
  def __init__(
225
240
  self,
@@ -296,11 +311,7 @@ class GPTModel(DeepEvalBaseLLM):
296
311
  # Generate functions
297
312
  ###############################################
298
313
 
299
- @retry(
300
- wait=wait_exponential_jitter(initial=1, exp_base=2, jitter=2, max=10),
301
- retry=retry_if_exception_type(retryable_exceptions),
302
- after=log_retry_error,
303
- )
314
+ @retry(**_base_retry_rules_kw)
304
315
  def generate(
305
316
  self, prompt: str, schema: Optional[BaseModel] = None
306
317
  ) -> Tuple[Union[str, Dict], float]:
@@ -359,11 +370,7 @@ class GPTModel(DeepEvalBaseLLM):
359
370
  else:
360
371
  return output, cost
361
372
 
362
- @retry(
363
- wait=wait_exponential_jitter(initial=1, exp_base=2, jitter=2, max=10),
364
- retry=retry_if_exception_type(retryable_exceptions),
365
- after=log_retry_error,
366
- )
373
+ @retry(**_base_retry_rules_kw)
367
374
  async def a_generate(
368
375
  self, prompt: str, schema: Optional[BaseModel] = None
369
376
  ) -> Tuple[Union[str, BaseModel], float]:
@@ -427,11 +434,7 @@ class GPTModel(DeepEvalBaseLLM):
427
434
  # Other generate functions
428
435
  ###############################################
429
436
 
430
- @retry(
431
- wait=wait_exponential_jitter(initial=1, exp_base=2, jitter=2, max=10),
432
- retry=retry_if_exception_type(retryable_exceptions),
433
- after=log_retry_error,
434
- )
437
+ @retry(**_base_retry_rules_kw)
435
438
  def generate_raw_response(
436
439
  self,
437
440
  prompt: str,
@@ -454,11 +457,7 @@ class GPTModel(DeepEvalBaseLLM):
454
457
 
455
458
  return completion, cost
456
459
 
457
- @retry(
458
- wait=wait_exponential_jitter(initial=1, exp_base=2, jitter=2, max=10),
459
- retry=retry_if_exception_type(retryable_exceptions),
460
- after=log_retry_error,
461
- )
460
+ @retry(**_base_retry_rules_kw)
462
461
  async def a_generate_raw_response(
463
462
  self,
464
463
  prompt: str,
@@ -481,11 +480,7 @@ class GPTModel(DeepEvalBaseLLM):
481
480
 
482
481
  return completion, cost
483
482
 
484
- @retry(
485
- wait=wait_exponential_jitter(initial=1, exp_base=2, jitter=2, max=10),
486
- retry=retry_if_exception_type(retryable_exceptions),
487
- after=log_retry_error,
488
- )
483
+ @retry(**_base_retry_rules_kw)
489
484
  def generate_samples(
490
485
  self, prompt: str, n: int, temperature: float
491
486
  ) -> Tuple[list[str], float]:
@@ -518,12 +513,13 @@ class GPTModel(DeepEvalBaseLLM):
518
513
  return self.model_name
519
514
 
520
515
  def load_model(self, async_mode: bool = False):
516
+ kwargs = {**self.kwargs, **_openai_client_kwargs()}
521
517
  if not async_mode:
522
518
  return OpenAI(
523
519
  api_key=self._openai_api_key,
524
520
  base_url=self.base_url,
525
- **self.kwargs,
521
+ **kwargs,
526
522
  )
527
523
  return AsyncOpenAI(
528
- api_key=self._openai_api_key, base_url=self.base_url, **self.kwargs
524
+ api_key=self._openai_api_key, base_url=self.base_url, **kwargs
529
525
  )
@@ -0,0 +1,280 @@
1
+ """Generic retry policy helpers for provider SDKs.
2
+
3
+ This module lets models define *what is transient* vs *non-retryable* (permanent) failure
4
+ without coupling to a specific SDK. You provide an `ErrorPolicy` describing
5
+ exception classes and special “non-retryable” error codes, such as quota-exhausted from OpenAI,
6
+ and get back a Tenacity predicate suitable for `retry_if_exception`.
7
+
8
+ Typical use:
9
+
10
+ # Import dependencies
11
+ from tenacity import retry, before_sleep_log
12
+ from deepeval.models.retry_policy import (
13
+ OPENAI_ERROR_POLICY, default_wait, default_stop, retry_predicate
14
+ )
15
+
16
+ # Define retry rule keywords
17
+ _retry_kw = dict(
18
+ wait=default_wait(),
19
+ stop=default_stop(),
20
+ retry=retry_predicate(OPENAI_ERROR_POLICY),
21
+ before_sleep=before_sleep_log(logger, logging.INFO), # <- Optional: logs only on retries
22
+ )
23
+
24
+ # Apply retry rule keywords where desired
25
+ @retry(**_retry_kw)
26
+ def call_openai(...):
27
+ ...
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import logging
33
+
34
+ from deepeval.utils import read_env_int, read_env_float
35
+ from dataclasses import dataclass, field
36
+ from typing import Iterable, Mapping, Callable, Sequence, Tuple
37
+ from collections.abc import Mapping as ABCMapping
38
+ from tenacity import (
39
+ wait_exponential_jitter,
40
+ stop_after_attempt,
41
+ retry_if_exception,
42
+ )
43
+
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # --------------------------
48
+ # Policy description
49
+ # --------------------------
50
+
51
+
52
+ @dataclass(frozen=True)
53
+ class ErrorPolicy:
54
+ """Describe exception classes & rules for retry classification.
55
+
56
+ Attributes:
57
+ auth_excs: Exceptions that indicate authentication/authorization problems.
58
+ These are treated as non-retryable.
59
+ rate_limit_excs: Exceptions representing rate limiting (HTTP 429).
60
+ network_excs: Exceptions for timeouts / connection issues (transient).
61
+ http_excs: Exceptions carrying an integer `status_code` (4xx, 5xx)
62
+ non_retryable_codes: Error “code” strings that should be considered permanent,
63
+ such as "insufficient_quota". Used to refine rate-limit handling.
64
+ retry_5xx: Whether to retry provider 5xx responses (defaults to True).
65
+ """
66
+
67
+ auth_excs: Tuple[type[Exception], ...]
68
+ rate_limit_excs: Tuple[type[Exception], ...]
69
+ network_excs: Tuple[type[Exception], ...]
70
+ http_excs: Tuple[type[Exception], ...]
71
+ non_retryable_codes: frozenset[str] = field(default_factory=frozenset)
72
+ retry_5xx: bool = True
73
+ message_markers: Mapping[str, Iterable[str]] = field(default_factory=dict)
74
+
75
+
76
+ # --------------------------
77
+ # Extraction helpers
78
+ # --------------------------
79
+
80
+
81
+ def extract_error_code(
82
+ e: Exception,
83
+ *,
84
+ response_attr: str = "response",
85
+ body_attr: str = "body",
86
+ code_path: Sequence[str] = ("error", "code"),
87
+ message_markers: Mapping[str, Iterable[str]] | None = None,
88
+ ) -> str:
89
+ """Best effort extraction of an error 'code' for SDK compatibility.
90
+
91
+ Order of attempts:
92
+ 1) Structured JSON via `e.response.json()` (typical HTTP error payload).
93
+ 2) A dict stored on `e.body` (some gateways/proxies use this).
94
+ 3) Message sniffing fallback, using `message_markers`.
95
+
96
+ Args:
97
+ e: The exception raised by the SDK/provider client.
98
+ response_attr: Attribute name that holds an HTTP response object.
99
+ body_attr: Attribute name that may hold a parsed payload (dict).
100
+ code_path: Path of keys to traverse to the code (e.g., ["error", "code"]).
101
+ message_markers: Mapping from canonical code -> substrings to search for.
102
+
103
+ Returns:
104
+ The code string if found, else "".
105
+ """
106
+ # 1) Structured JSON in e.response.json()
107
+ resp = getattr(e, response_attr, None)
108
+ if resp is not None:
109
+ try:
110
+ cur = resp.json()
111
+ for k in code_path:
112
+ if not isinstance(cur, ABCMapping):
113
+ cur = {}
114
+ break
115
+ cur = cur.get(k, {})
116
+ if isinstance(cur, (str, int)):
117
+ return str(cur)
118
+ except Exception:
119
+ # response.json() can raise; ignore and fall through
120
+ pass
121
+
122
+ # 2) SDK provided dict body
123
+ body = getattr(e, body_attr, None)
124
+ if isinstance(body, ABCMapping):
125
+ cur = body
126
+ for k in code_path:
127
+ if not isinstance(cur, ABCMapping):
128
+ cur = {}
129
+ break
130
+ cur = cur.get(k, {})
131
+ if isinstance(cur, (str, int)):
132
+ return str(cur)
133
+
134
+ # 3) Message sniff (hopefully this helps catch message codes that slip past the previous 2 parsers)
135
+ msg = str(e).lower()
136
+ markers = message_markers or {}
137
+ for code_key, needles in markers.items():
138
+ if any(n in msg for n in needles):
139
+ return code_key
140
+
141
+ return ""
142
+
143
+
144
+ # --------------------------
145
+ # Predicate factory
146
+ # --------------------------
147
+
148
+
149
+ def make_is_transient(
150
+ policy: ErrorPolicy,
151
+ *,
152
+ message_markers: Mapping[str, Iterable[str]] | None = None,
153
+ extra_non_retryable_codes: Iterable[str] = (),
154
+ ) -> Callable[[Exception], bool]:
155
+ """Create a Tenacity predicate: True = retry, False = surface immediately.
156
+
157
+ Semantics:
158
+ - Auth errors: non-retryable.
159
+ - Rate limit errors: retry unless the extracted code is in the non-retryable set
160
+ - Network/timeout errors: retry.
161
+ - HTTP errors with a `status_code`: retry 5xx if `policy.retry_5xx` is True.
162
+ - Everything else: treated as non-retryable.
163
+
164
+ Args:
165
+ policy: An ErrorPolicy describing error classes and rules.
166
+ message_markers: Optional override/extension for code inference via message text.
167
+ extra_non_retryable_codes: Additional code strings to treat as non-retryable.
168
+
169
+ Returns:
170
+ A callable `predicate(e) -> bool` suitable for `retry_if_exception`.
171
+ """
172
+ non_retryable = frozenset(policy.non_retryable_codes) | frozenset(
173
+ extra_non_retryable_codes
174
+ )
175
+
176
+ def _pred(e: Exception) -> bool:
177
+ if isinstance(e, policy.auth_excs):
178
+ return False
179
+
180
+ if isinstance(e, policy.rate_limit_excs):
181
+ code = extract_error_code(
182
+ e, message_markers=(message_markers or policy.message_markers)
183
+ )
184
+ return code not in non_retryable
185
+
186
+ if isinstance(e, policy.network_excs):
187
+ return True
188
+
189
+ if isinstance(e, policy.http_excs):
190
+ try:
191
+ sc = int(getattr(e, "status_code", 0))
192
+ except Exception:
193
+ sc = 0
194
+ return policy.retry_5xx and 500 <= sc < 600
195
+
196
+ return False
197
+
198
+ return _pred
199
+
200
+
201
+ # --------------------------
202
+ # Tenacity convenience
203
+ # --------------------------
204
+
205
+
206
+ def default_wait():
207
+ """Default backoff: exponential with jitter, capped.
208
+ Overridable via env:
209
+ - DEEPEVAL_RETRY_INITIAL_SECONDS (>=0)
210
+ - DEEPEVAL_RETRY_EXP_BASE (>=1)
211
+ - DEEPEVAL_RETRY_JITTER (>=0)
212
+ - DEEPEVAL_RETRY_CAP_SECONDS (>=0)
213
+ """
214
+ initial = read_env_float(
215
+ "DEEPEVAL_RETRY_INITIAL_SECONDS", 1.0, min_value=0.0
216
+ )
217
+ exp_base = read_env_float("DEEPEVAL_RETRY_EXP_BASE", 2.0, min_value=1.0)
218
+ jitter = read_env_float("DEEPEVAL_RETRY_JITTER", 2.0, min_value=0.0)
219
+ cap = read_env_float("DEEPEVAL_RETRY_CAP_SECONDS", 5.0, min_value=0.0)
220
+ return wait_exponential_jitter(
221
+ initial=initial, exp_base=exp_base, jitter=jitter, max=cap
222
+ )
223
+
224
+
225
+ def default_stop():
226
+ """Default stop condition: at most N attempts (N-1 retries).
227
+ Overridable via env:
228
+ - DEEPEVAL_RETRY_MAX_ATTEMPTS (>=1)
229
+ """
230
+ attempts = read_env_int("DEEPEVAL_RETRY_MAX_ATTEMPTS", 2, min_value=1)
231
+ return stop_after_attempt(attempts)
232
+
233
+
234
+ def retry_predicate(policy: ErrorPolicy, **kw):
235
+ """Build a Tenacity `retry=` argument from a policy.
236
+
237
+ Example:
238
+ retry=retry_predicate(OPENAI_ERROR_POLICY, extra_non_retryable_codes=["some_code"])
239
+ """
240
+ return retry_if_exception(make_is_transient(policy, **kw))
241
+
242
+
243
+ # --------------------------
244
+ # Built-in policies
245
+ # --------------------------
246
+ OPENAI_MESSAGE_MARKERS: dict[str, tuple[str, ...]] = {
247
+ "insufficient_quota": ("insufficient_quota", "exceeded your current quota"),
248
+ }
249
+
250
+ try:
251
+ from openai import (
252
+ AuthenticationError,
253
+ RateLimitError,
254
+ APIConnectionError,
255
+ APITimeoutError,
256
+ APIStatusError,
257
+ )
258
+
259
+ OPENAI_ERROR_POLICY = ErrorPolicy(
260
+ auth_excs=(AuthenticationError,),
261
+ rate_limit_excs=(RateLimitError,),
262
+ network_excs=(APIConnectionError, APITimeoutError),
263
+ http_excs=(APIStatusError,),
264
+ non_retryable_codes=frozenset({"insufficient_quota"}),
265
+ message_markers=OPENAI_MESSAGE_MARKERS,
266
+ )
267
+ except Exception: # pragma: no cover - OpenAI may not be installed in some envs
268
+ OPENAI_ERROR_POLICY = None
269
+
270
+
271
+ __all__ = [
272
+ "ErrorPolicy",
273
+ "extract_error_code",
274
+ "make_is_transient",
275
+ "default_wait",
276
+ "default_stop",
277
+ "retry_predicate",
278
+ "OPENAI_MESSAGE_MARKERS",
279
+ "OPENAI_ERROR_POLICY",
280
+ ]
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass, field, replace
4
- from typing import Any, Optional, Awaitable, Callable
4
+ from typing import Any, Optional, Awaitable, Callable, Generic, TypeVar
5
5
 
6
6
  from deepeval.tracing import observe
7
7
  from deepeval.prompt import Prompt
@@ -14,6 +14,8 @@ except Exception as e:
14
14
  "openai-agents is required for this integration. Please install it."
15
15
  ) from e
16
16
 
17
+ TContext = TypeVar("TContext")
18
+
17
19
 
18
20
  class _ObservedModel(Model):
19
21
  def __init__(
@@ -153,7 +155,7 @@ class _ObservedProvider(ModelProvider):
153
155
 
154
156
 
155
157
  @dataclass
156
- class DeepEvalAgent(BaseAgent[Any]):
158
+ class DeepEvalAgent(BaseAgent[TContext], Generic[TContext]):
157
159
  """
158
160
  A subclass of agents.Agent that accepts `metrics` and `metric_collection`
159
161
  and ensures the underlying model's `get_response` is wrapped with deepeval.observe.
@@ -1,47 +1,72 @@
1
- from typing import Optional, List, Dict, Union, Type
2
1
  import os
3
2
 
3
+ from typing import Dict, List, Optional, Type, TYPE_CHECKING
4
+ from types import SimpleNamespace
5
+
4
6
  from deepeval.models.base_model import DeepEvalBaseEmbeddingModel
5
7
 
6
- # check langchain availability
7
- try:
8
+
9
+ if TYPE_CHECKING:
10
+ from chromadb.api.models.Collection import Collection
8
11
  from langchain_core.documents import Document as LCDocument
9
- from langchain_text_splitters import TokenTextSplitter
10
12
  from langchain_text_splitters.base import TextSplitter
11
- from langchain_community.document_loaders import (
12
- PyPDFLoader,
13
- TextLoader,
14
- Docx2txtLoader,
15
- )
16
13
  from langchain_community.document_loaders.base import BaseLoader
17
14
 
18
- langchain_available = True
19
- except ImportError:
20
- langchain_available = False
21
-
22
- # check chromadb availability
23
- try:
24
- import chromadb
25
- from chromadb import Metadata
26
- from chromadb.api.models.Collection import Collection
27
-
28
- chroma_db_available = True
29
- except ImportError:
30
- chroma_db_available = False
31
15
 
32
-
33
- # Define a helper function to check availability
34
- def _check_chromadb_available():
35
- if not chroma_db_available:
16
+ # Lazy import caches
17
+ _langchain_ns = None
18
+ _chroma_mod = None
19
+ _langchain_import_error = None
20
+ _chroma_import_error = None
21
+
22
+
23
+ def _get_langchain():
24
+ """Return a namespace of langchain classes, or raise ImportError with root cause."""
25
+ global _langchain_ns, _langchain_import_error
26
+ if _langchain_ns is not None:
27
+ return _langchain_ns
28
+ try:
29
+ from langchain_core.documents import Document as LCDocument # type: ignore
30
+ from langchain_text_splitters import TokenTextSplitter # type: ignore
31
+ from langchain_text_splitters.base import TextSplitter # type: ignore
32
+ from langchain_community.document_loaders import ( # type: ignore
33
+ PyPDFLoader,
34
+ TextLoader,
35
+ Docx2txtLoader,
36
+ )
37
+ from langchain_community.document_loaders.base import BaseLoader # type: ignore
38
+
39
+ _langchain_ns = SimpleNamespace(
40
+ LCDocument=LCDocument,
41
+ TokenTextSplitter=TokenTextSplitter,
42
+ TextSplitter=TextSplitter,
43
+ PyPDFLoader=PyPDFLoader,
44
+ TextLoader=TextLoader,
45
+ Docx2txtLoader=Docx2txtLoader,
46
+ BaseLoader=BaseLoader,
47
+ )
48
+ return _langchain_ns
49
+ except Exception as e:
50
+ _langchain_import_error = e
36
51
  raise ImportError(
37
- "chromadb is required for this functionality. Install it via your package manager"
52
+ f"langchain, langchain_community, and langchain_text_splitters are required. Root cause: {e}"
38
53
  )
39
54
 
40
55
 
41
- def _check_langchain_available():
42
- if not langchain_available:
56
+ def _get_chromadb():
57
+ """Return the chromadb module, or raise ImportError with root cause."""
58
+ global _chroma_mod, _chroma_import_error
59
+ if _chroma_mod is not None:
60
+ return _chroma_mod
61
+ try:
62
+ import chromadb
63
+
64
+ _chroma_mod = chromadb
65
+ return _chroma_mod
66
+ except Exception as e:
67
+ _chroma_import_error = e
43
68
  raise ImportError(
44
- "langchain, langchain_community, and langchain_text_splitters are required for this functionality. Install it via your package manager"
69
+ f"chromadb is required for this functionality. Root cause: {e}"
45
70
  )
46
71
 
47
72
 
@@ -50,22 +75,16 @@ class DocumentChunker:
50
75
  self,
51
76
  embedder: DeepEvalBaseEmbeddingModel,
52
77
  ):
53
- _check_chromadb_available()
54
- _check_langchain_available()
55
78
  self.text_token_count: Optional[int] = None # set later
56
79
 
57
80
  self.source_file: Optional[str] = None
58
81
  self.chunks: Optional["Collection"] = None
59
- self.sections: Optional[List[LCDocument]] = None
82
+ self.sections: Optional[List["LCDocument"]] = None
60
83
  self.embedder: DeepEvalBaseEmbeddingModel = embedder
61
84
  self.mean_embedding: Optional[float] = None
62
85
 
63
86
  # Mapping of file extensions to their respective loader classes
64
- self.loader_mapping: Dict[str, Type[BaseLoader]] = {
65
- ".pdf": PyPDFLoader,
66
- ".txt": TextLoader,
67
- ".docx": Docx2txtLoader,
68
- }
87
+ self.loader_mapping: Dict[str, "Type[BaseLoader]"] = {}
69
88
 
70
89
  #########################################################
71
90
  ### Chunking Docs #######################################
@@ -74,7 +93,8 @@ class DocumentChunker:
74
93
  async def a_chunk_doc(
75
94
  self, chunk_size: int = 1024, chunk_overlap: int = 0
76
95
  ) -> "Collection":
77
- _check_chromadb_available()
96
+ lc = _get_langchain()
97
+ chroma = _get_chromadb()
78
98
 
79
99
  # Raise error if chunk_doc is called before load_doc
80
100
  if self.sections is None or self.source_file is None:
@@ -85,13 +105,13 @@ class DocumentChunker:
85
105
  # Create ChromaDB client
86
106
  full_document_path, _ = os.path.splitext(self.source_file)
87
107
  document_name = os.path.basename(full_document_path)
88
- client = chromadb.PersistentClient(path=f".vector_db/{document_name}")
108
+ client = chroma.PersistentClient(path=f".vector_db/{document_name}")
89
109
 
90
110
  collection_name = f"processed_chunks_{chunk_size}_{chunk_overlap}"
91
111
  try:
92
112
  collection = client.get_collection(name=collection_name)
93
113
  except Exception:
94
- text_splitter: TextSplitter = TokenTextSplitter(
114
+ text_splitter: "TextSplitter" = lc.TokenTextSplitter(
95
115
  chunk_size=chunk_size, chunk_overlap=chunk_overlap
96
116
  )
97
117
  # Collection doesn't exist, so create it and then add documents
@@ -108,7 +128,7 @@ class DocumentChunker:
108
128
  batch_contents = contents[i:batch_end]
109
129
  batch_embeddings = embeddings[i:batch_end]
110
130
  batch_ids = ids[i:batch_end]
111
- batch_metadatas: List["Metadata"] = [
131
+ batch_metadatas: List[dict] = [
112
132
  {"source_file": self.source_file} for _ in batch_contents
113
133
  ]
114
134
 
@@ -121,7 +141,8 @@ class DocumentChunker:
121
141
  return collection
122
142
 
123
143
  def chunk_doc(self, chunk_size: int = 1024, chunk_overlap: int = 0):
124
- _check_chromadb_available()
144
+ lc = _get_langchain()
145
+ chroma = _get_chromadb()
125
146
 
126
147
  # Raise error if chunk_doc is called before load_doc
127
148
  if self.sections is None or self.source_file is None:
@@ -132,13 +153,13 @@ class DocumentChunker:
132
153
  # Create ChromaDB client
133
154
  full_document_path, _ = os.path.splitext(self.source_file)
134
155
  document_name = os.path.basename(full_document_path)
135
- client = chromadb.PersistentClient(path=f".vector_db/{document_name}")
156
+ client = chroma.PersistentClient(path=f".vector_db/{document_name}")
136
157
 
137
158
  collection_name = f"processed_chunks_{chunk_size}_{chunk_overlap}"
138
159
  try:
139
160
  collection = client.get_collection(name=collection_name)
140
161
  except Exception:
141
- text_splitter: TextSplitter = TokenTextSplitter(
162
+ text_splitter: "TextSplitter" = lc.TokenTextSplitter(
142
163
  chunk_size=chunk_size, chunk_overlap=chunk_overlap
143
164
  )
144
165
  # Collection doesn't exist, so create it and then add documents
@@ -155,7 +176,7 @@ class DocumentChunker:
155
176
  batch_contents = contents[i:batch_end]
156
177
  batch_embeddings = embeddings[i:batch_end]
157
178
  batch_ids = ids[i:batch_end]
158
- batch_metadatas: List["Metadata"] = [
179
+ batch_metadatas: List[dict] = [
159
180
  {"source_file": self.source_file} for _ in batch_contents
160
181
  ]
161
182
 
@@ -172,17 +193,31 @@ class DocumentChunker:
172
193
  #########################################################
173
194
 
174
195
  def get_loader(self, path: str, encoding: Optional[str]) -> "BaseLoader":
196
+ lc = _get_langchain()
197
+ # set mapping lazily now that langchain classes exist
198
+ if not self.loader_mapping:
199
+ self.loader_mapping = {
200
+ ".pdf": lc.PyPDFLoader,
201
+ ".txt": lc.TextLoader,
202
+ ".docx": lc.Docx2txtLoader,
203
+ ".md": lc.TextLoader,
204
+ ".markdown": lc.TextLoader,
205
+ ".mdx": lc.TextLoader,
206
+ }
207
+
175
208
  # Find appropriate doc loader
176
209
  _, extension = os.path.splitext(path)
177
210
  extension = extension.lower()
178
- loader: Optional[type[BaseLoader]] = self.loader_mapping.get(extension)
211
+ loader: Optional["Type[BaseLoader]"] = self.loader_mapping.get(
212
+ extension
213
+ )
179
214
  if loader is None:
180
215
  raise ValueError(f"Unsupported file format: {extension}")
181
216
 
182
- # Load doc into sections and calculate total character count
183
- if loader is TextLoader:
217
+ # Load doc into sections and calculate total token count
218
+ if loader is lc.TextLoader:
184
219
  return loader(path, encoding=encoding, autodetect_encoding=True)
185
- elif loader is PyPDFLoader or loader is Docx2txtLoader:
220
+ elif loader in (lc.PyPDFLoader, lc.Docx2txtLoader):
186
221
  return loader(path)
187
222
  else:
188
223
  raise ValueError(f"Unsupported file format: {extension}")
@@ -200,5 +235,6 @@ class DocumentChunker:
200
235
  self.source_file = path
201
236
 
202
237
  def count_tokens(self, chunks: List["LCDocument"]):
203
- counter = TokenTextSplitter(chunk_size=1, chunk_overlap=0)
238
+ lc = _get_langchain()
239
+ counter = lc.TokenTextSplitter(chunk_size=1, chunk_overlap=0)
204
240
  return len(counter.split_documents(chunks))