versionhq 1.1.10.8__py3-none-any.whl → 1.1.11.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (35)
  1. versionhq/__init__.py +1 -1
  2. versionhq/_utils/vars.py +2 -0
  3. versionhq/agent/TEMPLATES/Backstory.py +2 -2
  4. versionhq/agent/default_agents.py +15 -0
  5. versionhq/agent/model.py +127 -39
  6. versionhq/agent/parser.py +3 -20
  7. versionhq/{_utils → agent}/rpm_controller.py +22 -15
  8. versionhq/knowledge/__init__.py +0 -0
  9. versionhq/knowledge/_utils.py +11 -0
  10. versionhq/knowledge/embedding.py +192 -0
  11. versionhq/knowledge/model.py +54 -0
  12. versionhq/knowledge/source.py +413 -0
  13. versionhq/knowledge/source_docling.py +129 -0
  14. versionhq/knowledge/storage.py +177 -0
  15. versionhq/llm/model.py +76 -62
  16. versionhq/memory/__init__.py +0 -0
  17. versionhq/memory/contextual_memory.py +96 -0
  18. versionhq/memory/model.py +174 -0
  19. versionhq/storage/base.py +14 -0
  20. versionhq/storage/ltm_sqlite_storage.py +131 -0
  21. versionhq/storage/mem0_storage.py +109 -0
  22. versionhq/storage/rag_storage.py +231 -0
  23. versionhq/storage/task_output_storage.py +18 -29
  24. versionhq/storage/utils.py +26 -0
  25. versionhq/task/TEMPLATES/Description.py +5 -0
  26. versionhq/task/evaluate.py +122 -0
  27. versionhq/task/model.py +134 -43
  28. versionhq/team/team_planner.py +1 -1
  29. versionhq/tool/model.py +44 -46
  30. {versionhq-1.1.10.8.dist-info → versionhq-1.1.11.0.dist-info}/METADATA +54 -40
  31. versionhq-1.1.11.0.dist-info/RECORD +64 -0
  32. versionhq-1.1.10.8.dist-info/RECORD +0 -45
  33. {versionhq-1.1.10.8.dist-info → versionhq-1.1.11.0.dist-info}/LICENSE +0 -0
  34. {versionhq-1.1.10.8.dist-info → versionhq-1.1.11.0.dist-info}/WHEEL +0 -0
  35. {versionhq-1.1.10.8.dist-info → versionhq-1.1.11.0.dist-info}/top_level.txt +0 -0
versionhq/__init__.py CHANGED
@@ -18,7 +18,7 @@ from versionhq.tool.model import Tool
  from versionhq.tool.composio_tool import ComposioHandler


- __version__ = "1.1.10.8"
+ __version__ = "1.1.11.0"
  __all__ = [
      "Agent",
      "Customer",
versionhq/_utils/vars.py ADDED
@@ -0,0 +1,2 @@
+ KNOWLEDGE_DIRECTORY="knowledge"
+ MAX_FILE_NAME_LENGTH=255
versionhq/agent/TEMPLATES/Backstory.py CHANGED
@@ -1,4 +1,4 @@
- BACKSTORY_FULL="""You are an expert {role} with deep understanding of {knowledge} and highly skilled in {skillsets}. You have abilities to call the RAG tools that can {rag_tool_overview}. Your primary goal is to identify competitive solutions by leveraging your knowledge and skillsets to achieve the following goal: {goal}."""
+ BACKSTORY_FULL="""You are an expert {role} highly skilled in {skills}. You have abilities to query relevant information from the given knowledge sources and use tools such as {tools}. Leveraging these, you will identify competitive solutions to achieve the following goal: {goal}."""


- BACKSTORY_SHORT="""You are an expert {role} with the right skillsets and knowledge. Your primary goal is to identify competitive solutions by leveraging your knowledge and skillsets to achieve the following goal: {goal}."""
+ BACKSTORY_SHORT="""You are an expert {role} with relevant skillsets and abilities to query relevant information from the given knowledge sources. Leveraging these, you will identify competitive solutions to achieve the following goal: {goal}."""
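For reference, a minimal sketch of how the reworked template renders once the placeholders are filled; the role, goal, skills, and tools values below are hypothetical:

from versionhq.agent.TEMPLATES.Backstory import BACKSTORY_FULL

# Hypothetical inputs; per versionhq/agent/model.py below, the Agent lowercases
# its role/goal and joins skillset and tool names into comma-separated strings.
backstory = BACKSTORY_FULL.format(
    role="marketing analyst",
    goal="identify the best pricing strategy",
    skills="SEO, cohort analysis",
    tools="web_search, summarizer",
)
print(backstory)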
versionhq/agent/default_agents.py ADDED
@@ -0,0 +1,15 @@
+ from versionhq.agent.model import Agent
+ from versionhq.llm.model import DEFAULT_MODEL_NAME
+
+ """
+ List up agents to be called across the project.
+ """
+
+ client_manager = Agent(role="Client Manager", goal="communicate with clients on the task progress", llm=DEFAULT_MODEL_NAME)
+
+ task_evaluator = Agent(
+     role="Task Evaluator",
+     goal="score the output according to the given evaluation criteria.",
+     llm=DEFAULT_MODEL_NAME,
+     llm_config=dict(top_p=0.8, top_k=30, max_tokens=5000, temperature=0.9)
+ )
versionhq/agent/model.py CHANGED
@@ -1,5 +1,6 @@
  import os
  import uuid
+ import datetime
  from typing import Any, Dict, List, Optional, TypeVar, Callable, Type
  from typing_extensions import Self
  from dotenv import load_dotenv
@@ -10,8 +11,11 @@ from pydantic_core import PydanticCustomError

  from versionhq.llm.model import LLM, DEFAULT_CONTEXT_WINDOW_SIZE, DEFAULT_MODEL_NAME
  from versionhq.tool.model import Tool, ToolSet
+ from versionhq.knowledge.model import BaseKnowledgeSource, Knowledge
+ from versionhq.memory.contextual_memory import ContextualMemory
+ from versionhq.memory.model import ShortTermMemory, LongTermMemory, UserMemory
  from versionhq._utils.logger import Logger
- from versionhq._utils.rpm_controller import RPMController
+ from versionhq.agent.rpm_controller import RPMController
  from versionhq._utils.usage_metrics import UsageMetrics
  from versionhq._utils.process_config import process_config

@@ -20,9 +24,6 @@ load_dotenv(override=True)
  T = TypeVar("T", bound="Agent")


- # def _format_answer(agent, answer: str) -> AgentAction | AgentFinish:
- #     return AgentParser(agent=agent).parse(answer)
-
  # def mock_agent_ops_provider():
  #     def track_agent(*args, **kwargs):
  #         def noop(f):
@@ -94,9 +95,20 @@ class Agent(BaseModel):
      role: str = Field(description="role of the agent - used in summary and logs")
      goal: str = Field(description="concise goal of the agent (details are set in the Task instance)")
      backstory: Optional[str] = Field(default=None, description="developer prompt to the llm")
-     knowledge: Optional[str] = Field(default=None, description="external knowledge fed to the agent")
      skillsets: Optional[List[str]] = Field(default_factory=list)
-     tools: Optional[List[Tool | ToolSet | Type[Tool]]] = Field(default_factory=list)
+     tools: Optional[List[InstanceOf[Tool | ToolSet] | Type[Tool] | Any]] = Field(default_factory=list)
+
+     # knowledge
+     knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(default=None)
+     _knowledge: Optional[Knowledge] = PrivateAttr(default=None)
+
+     # memory
+     use_memory: bool = Field(default=False, description="whether to store/use memory when executing the task")
+     memory_config: Optional[Dict[str, Any]] = Field(default=None, description="configuration for the memory. need to store user_id for UserMemory")
+     short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(default=None)
+     long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field(default=None)
+     user_memory: Optional[InstanceOf[UserMemory]] = Field(default=None)
+     embedder_config: Optional[Dict[str, Any]] = Field(default=None, description="embedder configuration for the agent's knowledge")

      # prompting
      use_developer_prompt: Optional[bool] = Field(default=True, description="Use developer prompt when calling the llm")
@@ -116,10 +128,10 @@ class Agent(BaseModel):
      respect_context_window: bool = Field(default=True,description="Keep messages under the context window size by summarizing content")
      max_tokens: Optional[int] = Field(default=None, description="max. number of tokens for the agent's execution")
      max_execution_time: Optional[int] = Field(default=None, description="max. execution time for an agent to execute a task")
-     max_rpm: Optional[int] = Field(default=None, description="max. number of requests per minute for the agent execution")
+     max_rpm: Optional[int] = Field(default=None, description="max. number of requests per minute")
      llm_config: Optional[Dict[str, Any]] = Field(default=None, description="other llm config cascaded to the model")

-     # config, cache, error handling
+     # cache, error, ops handling
      formatting_errors: int = Field(default=0, description="number of formatting errors.")
      agent_ops_agent_name: str = None
      agent_ops_agent_id: str = None
@@ -132,15 +144,6 @@ class Agent(BaseModel):
          raise PydanticCustomError("may_not_set_field", "This field is not to be set by the user.", {})


-     # @field_validator(mode="before")
-     # def set_up_from_config(cls) -> None:
-     #     if cls.config is not None:
-     #         try:
-     #             for k, v in cls.config.items():
-     #                 setattr(cls, k, v)
-     #         except:
-     #             pass
-
      @model_validator(mode="before")
      @classmethod
      def process_model_config(cls, values: Dict[str, Any]) -> None:
@@ -165,7 +168,6 @@ class Agent(BaseModel):
          """

          self.agent_ops_agent_name = self.role
-         # unaccepted_attributes = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]

          if isinstance(self.llm, LLM):
              llm = self._set_llm_params(self.llm)
@@ -314,29 +316,82 @@ class Agent(BaseModel):
          if self.backstory is None:
              from versionhq.agent.TEMPLATES.Backstory import BACKSTORY_FULL, BACKSTORY_SHORT
              backstory = ""
+             skills = ", ".join([item for item in self.skillsets]) if self.skillsets else ""
+             tools = ", ".join([item.name for item in self.tools if hasattr(item, "name")]) if self.tools else ""
+             role = self.role.lower()
+             goal = self.goal.lower()

-             if self.tools or self.knowledge or self.skillsets:
-                 backstory = BACKSTORY_FULL.format(
-                     role=self.role,
-                     goal=self.goal,
-                     knowledge=self.knowledge if isinstance(self.knowledge, str) else None,
-                     skillsets=", ".join([item for item in self.skillsets]),
-                     rag_tool_overview=", ".join([item.name for item in self.tools if hasattr(item, "name")]) if self.tools else "",
-                 )
+             if self.tools or self.skillsets:
+                 backstory = BACKSTORY_FULL.format(role=role, goal=goal, skills=skills, tools=tools)
              else:
-                 backstory = BACKSTORY_SHORT.format(role=self.role, goal=self.goal)
+                 backstory = BACKSTORY_SHORT.format(role=role, goal=goal)

              self.backstory = backstory

          return self


+     @model_validator(mode="after")
+     def set_up_rpm(self) -> Self:
+         """
+         Set up RPM controller.
+         """
+         if self.max_rpm:
+             self._rpm_controller = RPMController(max_rpm=self.max_rpm, _current_rpm=0)
+
+         return self
+
+
+     @model_validator(mode="after")
+     def set_up_knowledge(self) -> Self:
+         if self.knowledge_sources:
+             knowledge_agent_name = f"{self.role.replace(' ', '_')}"
+
+             if isinstance(self.knowledge_sources, list) and all(isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources):
+                 self._knowledge = Knowledge(
+                     sources=self.knowledge_sources,
+                     embedder_config=self.embedder_config,
+                     collection_name=knowledge_agent_name,
+                 )
+
+         return self
+
+
+     @model_validator(mode="after")
+     def set_up_memory(self) -> Self:
+         """
+         Set up memories: stm, um
+         """
+
+         if self.use_memory == True:
+             self.long_term_memory = self.long_term_memory if self.long_term_memory else LongTermMemory()
+             self.short_term_memory = self.short_term_memory if self.short_term_memory else ShortTermMemory(agent=self, embedder_config=self.embedder_config)
+
+             if hasattr(self, "memory_config") and self.memory_config is not None:
+                 user_id = self.memory_config.get("user_id", None)
+                 if user_id:
+                     self.user_memory = self.user_memory if self.user_memory else UserMemory(agent=self, user_id=user_id)
+             else:
+                 self.user_memory = None
+
+         return self
+
+
+     def _train(self) -> Self:
+         """
+         Fine-tuned the base model using OpenAI train framework.
+         """
+         if not isinstance(self.llm, LLM):
+             pass
+
+
      def invoke(
          self,
          prompts: str,
          response_format: Optional[Dict[str, Any]] = None,
-         tools: Optional[List[Tool | ToolSet | Type[Tool]]] = None,
-         tool_res_as_final: bool = False
+         tools: Optional[List[InstanceOf[Tool]| InstanceOf[ToolSet] | Type[Tool]]] = None,
+         tool_res_as_final: bool = False,
+         task: Any = None
      ) -> Dict[str, Any]:
          """
          Create formatted prompts using the developer prompt and the agent's backstory, then call the base model.
@@ -347,40 +402,51 @@ class Agent(BaseModel):
          task_execution_counter = 0
          iterations = 0
          raw_response = None
-         messages = []

-         messages.append({"role": "user", "content": prompts})
+         messages = []
+         messages.append({ "role": "user", "content": prompts })
          if self.use_developer_prompt:
-             messages.append({"role": "system", "content": self.backstory})
-             self._logger.log(level="info", message=f"Messages sent to the model: {messages}", color="blue")
+             messages.append({ "role": "system", "content": self.backstory })

          try:
-             if tool_res_as_final is True:
+             if self._rpm_controller and self.max_rpm:
+                 self._rpm_controller.check_or_wait()
+
+             self._logger.log(level="info", message=f"Messages sent to the model: {messages}", color="blue")
+
+             if tool_res_as_final:
                  func_llm = self.function_calling_llm if self.function_calling_llm and self.function_calling_llm._supports_function_calling() else LLM(model=DEFAULT_MODEL_NAME)
                  raw_response = func_llm.call(messages=messages, tools=tools, tool_res_as_final=True)
+                 task.tokens = func_llm._tokens
              else:
                  raw_response = self.llm.call(messages=messages, response_format=response_format, tools=tools)
+                 task.tokens = self.llm._tokens

              task_execution_counter += 1
              self._logger.log(level="info", message=f"Agent response: {raw_response}", color="blue")
-
+             return raw_response

          except Exception as e:
              self._logger.log(level="error", message=f"An error occured. The agent will retry: {str(e)}", color="red")

-         while not raw_response and task_execution_counter < self.max_retry_limit:
-             while not raw_response and iterations < self.maxit:
+         while not raw_response and task_execution_counter <= self.max_retry_limit:
+             while (not raw_response or raw_response == "" or raw_response is None) and iterations < self.maxit:
+                 if self.max_rpm and self._rpm_controller:
+                     self._rpm_controller.check_or_wait()
+
                  raw_response = self.llm.call(messages=messages, response_format=response_format, tools=tools)
+                 task.tokens = self.llm._tokens
                  iterations += 1

              task_execution_counter += 1
              self._logger.log(level="info", message=f"Agent #{task_execution_counter} response: {raw_response}", color="blue")
+             return raw_response

          if not raw_response:
              self._logger.log(level="error", message="Received None or empty response from the model", color="red")
              raise ValueError("Invalid response from LLM call - None or empty.")

-         return raw_response
+


      def execute_task(self, task, context: Optional[str] = None, task_tools: Optional[List[Tool | ToolSet]] = list()) -> str:
@@ -390,14 +456,35 @@ class Agent(BaseModel):
          The agent must consider the context to excute the task as well when it is given.
          """
          from versionhq.task.model import Task
+         from versionhq.knowledge._utils import extract_knowledge_context

          task: InstanceOf[Task] = task
-         tools: Optional[List[Tool | ToolSet | Type[Tool]]] = task_tools + self.tools if task.can_use_agent_tools else task_tools
+         tools: Optional[List[InstanceOf[Tool]| InstanceOf[ToolSet] | Type[Tool]]] = task_tools + self.tools if task.can_use_agent_tools else task_tools
+
+         if self.max_rpm and self._rpm_controller:
+             self._rpm_controller._reset_request_count()

          task_prompt = task.prompt(model_provider=self.llm.provider)
          if context is not task.prompt_context:
              task_prompt += context

+         if self._knowledge:
+             agent_knowledge = self._knowledge.query(query=[task_prompt,])
+             if agent_knowledge:
+                 agent_knowledge_context = extract_knowledge_context(knowledge_snippets=agent_knowledge)
+                 if agent_knowledge_context:
+                     task_prompt += agent_knowledge_context
+
+
+         if self.use_memory == True:
+             contextual_memory = ContextualMemory(
+                 memory_config=self.memory_config, stm=self.short_term_memory, ltm=self.long_term_memory, um=self.user_memory
+             )
+             memory = contextual_memory.build_context_for_task(task=task, context=context)
+             if memory.strip() != "":
+                 task_prompt += memory.strip()
+
+
          # if self.team and self.team._train:
          #     task_prompt = self._training_handler(task_prompt=task_prompt)
          # else:
@@ -410,6 +497,7 @@ class Agent(BaseModel):
                  response_format=task._structure_response_format(model_provider=self.llm.provider),
                  tools=tools,
                  tool_res_as_final=task.tool_res_as_final,
+                 task=task
              )

          except Exception as e:
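Putting the new fields and validators together, constructing an agent along these lines should exercise the memory and RPM paths; a minimal sketch with hypothetical values, assuming valid API keys in the environment (knowledge_sources is omitted since it requires concrete BaseKnowledgeSource instances):

from versionhq.agent.model import Agent

# use_memory=True makes set_up_memory attach LongTermMemory/ShortTermMemory,
# plus UserMemory only because memory_config carries a user_id; max_rpm makes
# set_up_rpm attach an RPMController that invoke() checks before each LLM call.
agent = Agent(
    role="Research Analyst",
    goal="summarize market trends",
    use_memory=True,
    memory_config={"user_id": "user_123"},
    max_rpm=10,
    embedder_config={"provider": "openai", "config": {"model": "text-embedding-3-small"}},
)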
versionhq/agent/parser.py CHANGED
@@ -44,26 +44,6 @@ class OutputParserException(Exception):


  class AgentParser:
-     """
-     Parses ReAct-style LLM calls that have a single tool input.
-
-     Expects output to be in one of two formats.
-
-     If the output signals that an action should be taken,
-     should be in the below format. This will result in an AgentAction
-     being returned.
-
-     Thought: agent thought here
-     Action: search
-     Action Input: what is the temperature in SF?
-
-     If the output signals that a final answer should be given,
-     should be in the below format. This will result in an AgentFinish
-     being returned.
-
-     Thought: agent thought here
-     Final Answer: The temperature is 100 degrees
-     """

      agent: Any = None

@@ -111,6 +91,7 @@ class AgentParser:
              # self.agent.increment_formatting_errors()
              raise OutputParserException(error)

+
      def _extract_thought(self, text: str) -> str:
          regex = r"(.*?)(?:\n\nAction|\n\nFinal Answer)"
          thought_match = re.search(regex, text, re.DOTALL)
@@ -118,10 +99,12 @@ class AgentParser:
              return thought_match.group(1).strip()
          return ""

+
      def _clean_action(self, text: str) -> str:
          """Clean action string by removing non-essential formatting characters."""
          return re.sub(r"^\s*\*+\s*|\s*\*+\s*$", "", text).strip()

+
      def _safe_repair_json(self, tool_input: str) -> str:
          UNABLE_TO_REPAIR_JSON_RESULTS = ['""', "{}"]

versionhq/{_utils → agent}/rpm_controller.py CHANGED
@@ -10,7 +10,7 @@ from versionhq._utils.logger import Logger

  class RPMController(BaseModel):
      max_rpm: Optional[int] = Field(default=None)
-     logger: Logger = Field(default_factory=lambda: Logger(verbose=False))
+     _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=True))
      _current_rpm: int = PrivateAttr(default=0)
      _timer: Optional[threading.Timer] = PrivateAttr(default=None)
      _lock: Optional[threading.Lock] = PrivateAttr(default=None)
@@ -24,38 +24,45 @@ class RPMController(BaseModel):
          self._reset_request_count()
          return self

-     def check_or_wait(self) -> bool:
+
+     def _check_and_increment(self) -> bool:
          if self.max_rpm is None:
              return True

-         def _check_and_increment():
-             if self.max_rpm is not None and self._current_rpm < self.max_rpm:
-                 self._current_rpm += 1
-                 return True
-             elif self.max_rpm is not None:
-                 self.logger.log(
-                     "info", "Max RPM reached, waiting for next minute to start."
-                 )
-                 self._wait_for_next_minute()
-                 self._current_rpm = 1
+         elif self.max_rpm is not None and self._current_rpm < self.max_rpm:
+             self._current_rpm += 1
+             return True
+
+         elif self.max_rpm is not None and self._current_rpm >= self.max_rpm:
+             self._logger.log(level="info", message="Max RPM reached, waiting for next minute to start.", color="yellow")
+             self._wait_for_next_minute()
+             self._current_rpm = 1 # restart
              return True

+         else:
+             return False
+
+     def check_or_wait(self) -> bool:
          if self._lock:
              with self._lock:
-                 return _check_and_increment()
+                 return self._check_and_increment()
          else:
-             return _check_and_increment()
+             return self._check_and_increment()
+
+         return False
+

      def stop_rpm_counter(self) -> None:
          if self._timer:
              self._timer.cancel()
              self._timer = None

+
      def _wait_for_next_minute(self) -> None:
          time.sleep(60)
          self._current_rpm = 0

+
      def _reset_request_count(self) -> None:
          def _reset():
              self._current_rpm = 0
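A small sketch of the refactored controller used directly (values hypothetical); per the code above, a call past the limit within the same minute hits the max-RPM branch, logs, sleeps until the window resets, then proceeds:

from versionhq.agent.rpm_controller import RPMController

controller = RPMController(max_rpm=2)
for _ in range(3):
    controller.check_or_wait()  # increments _current_rpm, or waits out the minute
controller.stop_rpm_counter()   # cancel the pending reset timer when done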
versionhq/knowledge/__init__.py ADDED (file without changes)
versionhq/knowledge/_utils.py ADDED
@@ -0,0 +1,11 @@
+ from typing import List, Dict, Any
+
+
+ def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
+     """
+     Extract knowledge from the task prompt.
+     """
+
+     valid_snippets = [result["context"] for result in knowledge_snippets if result and result.get("context")]
+     snippet = "\n".join(valid_snippets)
+     return f"Additional Information: {snippet}" if valid_snippets else ""
versionhq/knowledge/embedding.py ADDED
@@ -0,0 +1,192 @@
+ import os
+ from typing import Any, Dict, cast
+
+ from chromadb import Documents, EmbeddingFunction, Embeddings
+ from chromadb.api.types import validate_embedding_function
+
+ from versionhq._utils.logger import Logger
+
+
+ class EmbeddingConfigurator:
+     def __init__(self):
+         self.embedding_functions = {
+             "openai": self._configure_openai,
+             "azure": self._configure_azure,
+             "ollama": self._configure_ollama,
+             "vertexai": self._configure_vertexai,
+             "google": self._configure_google,
+             "cohere": self._configure_cohere,
+             "voyageai": self._configure_voyageai,
+             "bedrock": self._configure_bedrock,
+             "huggingface": self._configure_huggingface,
+             "watson": self._configure_watson,
+         }
+
+     def configure_embedder(self, embedder_config: Dict[str, Any] | None = None) -> EmbeddingFunction:
+         """
+         Configures and returns an embedding function based on the provided config.
+         """
+
+         if embedder_config is None:
+             return self._create_default_embedding_function()
+
+         provider = embedder_config.get("provider")
+         config = embedder_config.get("config", {})
+         model_name = config.get("model")
+
+         if isinstance(provider, EmbeddingFunction):
+             try:
+                 validate_embedding_function(provider)
+                 return provider
+             except Exception as e:
+                 raise ValueError(f"Invalid custom embedding function: {str(e)}")
+
+         if provider not in self.embedding_functions:
+             raise Exception(
+                 f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
+             )
+
+         return self.embedding_functions[provider](config, model_name)
+
+     @staticmethod
+     def _create_default_embedding_function():
+         from chromadb.utils.embedding_functions.openai_embedding_function import (
+             OpenAIEmbeddingFunction,
+         )
+
+         return OpenAIEmbeddingFunction(
+             api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
+         )
+
+     @staticmethod
+     def _configure_openai(config, model_name):
+         from chromadb.utils.embedding_functions.openai_embedding_function import (
+             OpenAIEmbeddingFunction,
+         )
+
+         return OpenAIEmbeddingFunction(
+             api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
+             model_name=model_name,
+         )
+
+     @staticmethod
+     def _configure_azure(config, model_name):
+         from chromadb.utils.embedding_functions.openai_embedding_function import (
+             OpenAIEmbeddingFunction,
+         )
+
+         return OpenAIEmbeddingFunction(
+             api_key=config.get("api_key"),
+             api_base=config.get("api_base"),
+             api_type=config.get("api_type", "azure"),
+             api_version=config.get("api_version"),
+             model_name=model_name,
+         )
+
+     @staticmethod
+     def _configure_ollama(config, model_name):
+         from chromadb.utils.embedding_functions.ollama_embedding_function import (
+             OllamaEmbeddingFunction,
+         )
+
+         return OllamaEmbeddingFunction(
+             url=config.get("url", "http://localhost:11434/api/embeddings"),
+             model_name=model_name,
+         )
+
+     @staticmethod
+     def _configure_vertexai(config, model_name):
+         from chromadb.utils.embedding_functions.google_embedding_function import (
+             GoogleVertexEmbeddingFunction,
+         )
+
+         return GoogleVertexEmbeddingFunction(model_name=model_name, api_key=config.get("api_key"))
+
+     @staticmethod
+     def _configure_google(config, model_name):
+         from chromadb.utils.embedding_functions.google_embedding_function import (
+             GoogleGenerativeAiEmbeddingFunction,
+         )
+
+         return GoogleGenerativeAiEmbeddingFunction(model_name=model_name, api_key=config.get("api_key"))
+
+     @staticmethod
+     def _configure_cohere(config, model_name):
+         from chromadb.utils.embedding_functions.cohere_embedding_function import (
+             CohereEmbeddingFunction,
+         )
+
+         return CohereEmbeddingFunction(
+             model_name=model_name,
+             api_key=config.get("api_key"),
+         )
+
+     @staticmethod
+     def _configure_voyageai(config, model_name):
+         from chromadb.utils.embedding_functions.voyageai_embedding_function import (
+             VoyageAIEmbeddingFunction,
+         )
+
+         return VoyageAIEmbeddingFunction(
+             model_name=model_name,
+             api_key=config.get("api_key"),
+         )
+
+     @staticmethod
+     def _configure_bedrock(config, model_name):
+         from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
+             AmazonBedrockEmbeddingFunction,
+         )
+
+         return AmazonBedrockEmbeddingFunction(
+             session=config.get("session"),
+         )
+
+     @staticmethod
+     def _configure_huggingface(config, model_name):
+         from chromadb.utils.embedding_functions.huggingface_embedding_function import (
+             HuggingFaceEmbeddingServer,
+         )
+
+         return HuggingFaceEmbeddingServer(
+             url=config.get("api_url"),
+         )
+
+     @staticmethod
+     def _configure_watson(config, model_name):
+         try:
+             import ibm_watsonx_ai.foundation_models as watson_models
+             from ibm_watsonx_ai import Credentials
+             from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
+         except ImportError as e:
+             raise ImportError(
+                 "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
+             ) from e
+
+         class WatsonEmbeddingFunction(EmbeddingFunction):
+             def __call__(self, input: Documents) -> Embeddings:
+                 if isinstance(input, str):
+                     input = [input]
+
+                 embed_params = {
+                     EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
+                     EmbedParams.RETURN_OPTIONS: {"input_text": True},
+                 }
+
+                 embedding = watson_models.Embeddings(
+                     model_id=config.get("model"),
+                     params=embed_params,
+                     credentials=Credentials(
+                         api_key=config.get("api_key"), url=config.get("api_url")
+                     ),
+                     project_id=config.get("project_id"),
+                 )
+
+                 try:
+                     embeddings = embedding.embed_documents(input)
+                     return cast(Embeddings, embeddings)
+                 except Exception as e:
+                     Logger(verbose=True).log(level="error", message=f"Error during Watson embedding: {str(e)}", color="red")
+                     raise e
+
+         return WatsonEmbeddingFunction()
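A usage sketch for the new configurator; the provider and model values are illustrative, and OPENAI_API_KEY is assumed to be set in the environment:

from versionhq.knowledge.embedding import EmbeddingConfigurator

configurator = EmbeddingConfigurator()

# No config: falls back to OpenAI's text-embedding-3-small.
default_fn = configurator.configure_embedder()

# Explicit provider config, routed through _configure_openai.
openai_fn = configurator.configure_embedder({
    "provider": "openai",
    "config": {"model": "text-embedding-3-small"},
})
vectors = openai_fn(["What changed in v1.1.11.0?"])  # chromadb EmbeddingFunction call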