codeembed-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. codeembed/__init__.py +59 -0
  2. codeembed/bootstrap/__init__.py +17 -0
  3. codeembed/bootstrap/services.py +220 -0
  4. codeembed/cli.py +454 -0
  5. codeembed/config/__init__.py +5 -0
  6. codeembed/config/models.py +13 -0
  7. codeembed/cost_tracking/__init__.py +7 -0
  8. codeembed/cost_tracking/llm_wrapper.py +39 -0
  9. codeembed/cost_tracking/models.py +52 -0
  10. codeembed/delta_computer/__init__.py +5 -0
  11. codeembed/delta_computer/delta_computer.py +75 -0
  12. codeembed/doc_embedder/__init__.py +5 -0
  13. codeembed/doc_embedder/doc_embedder.py +134 -0
  14. codeembed/doc_provider/__init__.py +10 -0
  15. codeembed/doc_provider/base.py +14 -0
  16. codeembed/doc_provider/local_doc_provider.py +58 -0
  17. codeembed/doc_provider/models.py +20 -0
  18. codeembed/doc_search_service/__init__.py +5 -0
  19. codeembed/doc_search_service/doc_search_service.py +48 -0
  20. codeembed/doc_splitters/__init__.py +8 -0
  21. codeembed/doc_splitters/generic_splitter.py +165 -0
  22. codeembed/doc_splitters/models.py +14 -0
  23. codeembed/llm/__init__.py +13 -0
  24. codeembed/llm/base.py +31 -0
  25. codeembed/llm/models.py +27 -0
  26. codeembed/llm/ollama_adapter.py +64 -0
  27. codeembed/llm/openai_adapter.py +96 -0
  28. codeembed/mcp_server.py +45 -0
  29. codeembed/setup_logger.py +34 -0
  30. codeembed/utils/__init__.py +9 -0
  31. codeembed/utils/checksum_utils.py +5 -0
  32. codeembed/utils/string_utils.py +5 -0
  33. codeembed/utils/time_utils.py +5 -0
  34. codeembed/vector_db/__init__.py +9 -0
  35. codeembed/vector_db/base.py +27 -0
  36. codeembed/vector_db/chromadb_adapter.py +130 -0
  37. codeembed/vector_db/models.py +16 -0
  38. codeembed-0.1.0.dist-info/METADATA +292 -0
  39. codeembed-0.1.0.dist-info/RECORD +42 -0
  40. codeembed-0.1.0.dist-info/WHEEL +4 -0
  41. codeembed-0.1.0.dist-info/entry_points.txt +2 -0
  42. codeembed-0.1.0.dist-info/licenses/LICENSE +21 -0
codeembed/doc_embedder/doc_embedder.py ADDED
@@ -0,0 +1,134 @@
+ import logging
+ from typing import List
+ from uuid import uuid4
+
+ from codeembed.delta_computer.delta_computer import DeltaComputer
+ from codeembed.doc_provider.base import DocProviderBase
+ from codeembed.doc_splitters.generic_splitter import FileSplitter
+ from codeembed.doc_splitters.models import FileSegment
+ from codeembed.llm.base import LLMServiceBase
+ from codeembed.llm.models import ChatMessage
+ from codeembed.vector_db.base import VectorDbBase
+ from codeembed.vector_db.models import Chunk
+
+ logger = logging.getLogger(__name__)
+
+
+ def _segment_to_chunk(
+     llm_service: LLMServiceBase,
+     segment: FileSegment,
+     full_content: str,
+     file_path: str,
+     llm_model: str,
+ ) -> str:
+
+     # NOTE: For markdown files we could embed directly without LLM summarization.
+     # Just split on ## headers.
+
+     logger.info("Analyzing segment %s in file %s...", segment.content.split("\n")[0], file_path)
+
+     messages: List[ChatMessage] = [
+         {"role": "system", "content": "You are an expert at describing code."},
+         {
+             "role": "user",
+             "content": f"""In the context of the following file:
+ <FilePath>{file_path}</FilePath>
+ <FileContent>
+ {full_content}
+ </FileContent>
+ Please describe the purpose of the following code/text segment:
+ <Segment>
+ <LineStart>{segment.line_start}</LineStart>
+ <Content>
+ {segment.content}
+ </Content>
+ <LineEnd>{segment.line_end}</LineEnd>
+ </Segment>
+ If this is a function or class, please describe what it does and how it interacts with the application.
+ If it is a text paragraph, explain what it covers.
+ Focus on the key aspects of the text or code.
+ Write a succinct summary.
+ Return the summary only, without any additional comments.
+ Start with, e.g.,
+ This <segment type> is ...
+ """,
+         },
+     ]
+
+     result = llm_service.generate_response(messages, llm_model, max_tokens=1024, temperature=0.3)
+
+     logger.info("Generated summary for segment in file %s: %s", file_path, result.response)
+
+     return result.response
+
+
+ class DocEmbedder:
+     def __init__(
+         self,
+         doc_provider: DocProviderBase,
+         vector_db: VectorDbBase,
+         llm_service: LLMServiceBase,
+         llm_model: str,
+         debounce_seconds: int = 10,
+     ) -> None:
+         self._doc_provider = doc_provider
+         self._vector_db = vector_db
+         self._llm_service = llm_service
+         self._llm_model = llm_model
+         self._debounce_seconds = debounce_seconds
+
+     def embed_codebase(self) -> None:
+         """Embeds the codebase and prepares it for vector search."""
+
+         logger.info("Computing deltas...")
+
+         chunks_ids_to_remove, files_to_update = DeltaComputer(
+             self._doc_provider, self._vector_db, self._debounce_seconds
+         ).compute_deltas()
+
+         logger.info(f"Detected {len(chunks_ids_to_remove)} chunks to delete from vector database.")
+         logger.info(f"Detected {len(files_to_update)} files to process.")
+
+         if chunks_ids_to_remove:
+             logger.info(f"Deleting {len(chunks_ids_to_remove)} chunks from vector database.")
+             self._vector_db.delete_chunks(list(chunks_ids_to_remove))
+
+         logger.info(f"Processing {len(files_to_update)} files...")
+
+         num_processed = 0
+         num_skipped = 0
+
+         splitter = FileSplitter()
+
+         for i, file in enumerate(files_to_update):
+             logger.info(f"Processing file '{file}' ({i + 1}/{len(files_to_update)})...")
+             doc = self._doc_provider.get_content(file)
+             segments = splitter.split_file(doc.content, file)
+             chunks = []
+             for segment in segments:
+                 summary = _segment_to_chunk(self._llm_service, segment, doc.content, file, self._llm_model)
+                 chunks.append(
+                     Chunk(
+                         id=uuid4(),
+                         modified_at=doc.modified_at,
+                         content=summary,
+                         file_path=file,
+                         line_start=segment.line_start,
+                         line_end=segment.line_end,
+                         raw_code=segment.content,
+                         file_sha256_checksum=doc.sha256_checksum,
+                     )
+                 )
+             if not chunks:
+                 logger.warning(f"No chunks generated for file '{file}'. Skipping embedding for this file.")
+                 num_skipped += 1
+                 continue
+             logger.info(f"Saving {len(chunks)} chunks to vector database.")
+             self._vector_db.add_chunks(chunks)
+             num_processed += 1
+             logger.info(f"Successfully embedded file: '{file}' ({i + 1}/{len(files_to_update)}).")
+
+         if num_processed > 0:
+             logger.info(f"Successfully embedded {num_processed} files.")
+         if num_skipped > 0:
+             logger.warning(f"Skipped processing {num_skipped} files.")
codeembed/doc_provider/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from codeembed.doc_provider.base import DocProviderBase
+ from codeembed.doc_provider.local_doc_provider import LocalDocProvider
+ from codeembed.doc_provider.models import DocumentContent, DocumentMeta
+
+ __all__ = [
+     "DocProviderBase",
+     "DocumentContent",
+     "DocumentMeta",
+     "LocalDocProvider",
+ ]
codeembed/doc_provider/base.py ADDED
@@ -0,0 +1,14 @@
+ from abc import ABC, abstractmethod
+ from typing import Iterator
+
+ from codeembed.doc_provider.models import DocumentContent, DocumentMeta
+
+
+ class DocProviderBase(ABC):
+     @abstractmethod
+     def iter(self) -> Iterator[DocumentMeta]:
+         """Iterates over the metadata of available files."""
+
+     @abstractmethod
+     def get_content(self, file_path: str) -> DocumentContent:
+         """Gets the actual file content."""
codeembed/doc_provider/local_doc_provider.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import subprocess
+ from datetime import datetime, timezone
+ from typing import Iterator, List
+
+ from codeembed.doc_provider.base import DocProviderBase
+ from codeembed.doc_provider.models import DocumentContent, DocumentMeta
+
+ _SKIP_DIRS = frozenset({"venv", ".venv", "node_modules", "dist", "build"})
+ _SKIP_FILES = frozenset({"__init__.py", ".env", ".env.local", "appsettings.json", "appsettings.Development.json"})
+
+
+ def _get_git_files(base_path: str) -> set[str]:
+     result = subprocess.run(
+         ["git", "ls-files", "--cached", "--others", "--exclude-standard"],
+         cwd=base_path,
+         capture_output=True,
+         text=True,
+     )
+     return set(result.stdout.splitlines())
+
+
+ class LocalDocProvider(DocProviderBase):
+     def __init__(self, base_path: str, supported_file_extensions: List[str]) -> None:
+         self._base_path = base_path
+         self._supported_file_extensions = [ext.lower().split(".")[-1] for ext in supported_file_extensions]
+
+     def iter(self) -> Iterator[DocumentMeta]:
+
+         file_paths = _get_git_files(self._base_path)
+
+         docs: List[DocumentMeta] = []
+         for file_path in file_paths:
+             ext = file_path.split(".")[-1]
+             if ext.lower() not in self._supported_file_extensions:
+                 continue
+
+             parts = file_path.split("/")
+             if parts[-1] in _SKIP_FILES or any(d in _SKIP_DIRS for d in parts[:-1]):
+                 continue
+
+             try:
+                 modified_ts = os.path.getmtime(os.path.join(self._base_path, file_path))  # git paths are repo-relative
+                 modified_at = datetime.fromtimestamp(modified_ts, tz=timezone.utc)
+             except OSError:
+                 continue
+
+             docs.append(DocumentMeta(file_path=file_path, modified_at=modified_at))
+
+         docs.sort(key=lambda d: d.modified_at, reverse=True)
+         yield from docs
+
+     def get_content(self, file_path: str) -> DocumentContent:
+         with open(os.path.join(self._base_path, file_path), "r", encoding="utf-8") as f:
+             content = f.read()
+         modified_ts = os.path.getmtime(os.path.join(self._base_path, file_path))
+         modified_at = datetime.fromtimestamp(modified_ts, tz=timezone.utc)
+         return DocumentContent(content=content, modified_at=modified_at)
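
A small standalone sketch of the provider, assuming the base path is a git repository (file paths come from git ls-files); it only uses calls defined in this file and in models.py below:

    from codeembed.doc_provider.local_doc_provider import LocalDocProvider

    provider = LocalDocProvider(base_path=".", supported_file_extensions=[".py", ".md"])
    for meta in provider.iter():  # most recently modified files first
        doc = provider.get_content(meta.file_path)
        print(meta.file_path, meta.modified_at, doc.sha256_checksum[:8])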
codeembed/doc_provider/models.py ADDED
@@ -0,0 +1,20 @@
+ from dataclasses import dataclass
+ from datetime import datetime
+
+ from codeembed.utils.checksum_utils import string_to_sha256
+
+
+ @dataclass
+ class DocumentMeta:
+     file_path: str
+     modified_at: datetime
+
+
+ @dataclass
+ class DocumentContent:
+     content: str
+     modified_at: datetime
+
+     @property
+     def sha256_checksum(self) -> str:
+         return string_to_sha256(self.content)
codeembed/doc_search_service/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from codeembed.doc_search_service.doc_search_service import DocSearchService
+
+ __all__ = [
+     "DocSearchService",
+ ]
codeembed/doc_search_service/doc_search_service.py ADDED
@@ -0,0 +1,48 @@
+ from typing import Dict, List
+
+ from codeembed.utils.string_utils import truncate_string
+ from codeembed.vector_db.base import VectorDbBase
+ from codeembed.vector_db.models import Chunk
+
+
+ class DocSearchService:
+     """
+     The service that searches the vector database for relevant content and formats it for LLM consumption.
+     """
+
+     def __init__(
+         self,
+         vector_db: VectorDbBase,
+     ) -> None:
+         self._vector_db = vector_db
+
+     def search(self, query: str, top_n: int = 10) -> str:
+         """Searches the vector database for relevant content and formats it for LLM consumption."""
+         chunks = self._vector_db.search(query, top_n)
+
+         chunks_by_file: Dict[str, List[Chunk]] = {}
+
+         for chunk in chunks:
+             if chunk.file_path not in chunks_by_file:
+                 chunks_by_file[chunk.file_path] = []
+             chunks_by_file[chunk.file_path].append(chunk)
+
+         res = f"<SearchQuery>{query}</SearchQuery>\n"
+         res += f"<TopN>{top_n}</TopN>\n"
+         res += f'<Results chunkCount="{len(chunks)}" fileCount="{len(chunks_by_file)}">\n'
+         for file_path, file_chunks in chunks_by_file.items():
+             res += f'  <File path="{file_path}">\n'
+             for chunk in file_chunks:
+                 # NOTE: Consider truncating by number of tokens.
+                 raw_code = chunk.raw_code if chunk.raw_code else ""
+                 res += "    <Chunk>\n"
+                 res += f"      <Summary>\n{truncate_string(chunk.content, 4096)}\n      </Summary>\n"
+                 res += (
+                     f'      <RawCode lines="{chunk.line_start}-{chunk.line_end}">\n'
+                     f"{truncate_string(raw_code, 4096)}\n"
+                     f"      </RawCode>\n"
+                 )
+                 res += "    </Chunk>\n"
+             res += "  </File>\n"
+         res += "</Results>\n"
+         return res
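
A short usage sketch, assuming `vector_db` is an instance of any VectorDbBase implementation (the package ships a ChromaDB adapter, but its constructor is not shown in this diff):

    from codeembed.doc_search_service import DocSearchService

    search_service = DocSearchService(vector_db=vector_db)  # vector_db: any VectorDbBase
    context = search_service.search("where is user authentication handled?", top_n=5)
    print(context)  # XML-like block: <SearchQuery>...</SearchQuery><TopN>...<Results>...</Results>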
codeembed/doc_splitters/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from codeembed.doc_splitters.generic_splitter import FileSplitter
+ from codeembed.doc_splitters.models import FileSegment, SplittedFile
+
+ __all__ = [
+     "FileSegment",
+     "FileSplitter",
+     "SplittedFile",
+ ]
codeembed/doc_splitters/generic_splitter.py ADDED
@@ -0,0 +1,165 @@
+ from typing import Dict, List
+
+ import tiktoken
+
+ from codeembed.doc_splitters.models import FileSegment
+
+ _encoder = tiktoken.get_encoding("o200k_base")
+
+ _SPLIT_KEYWORDS: Dict[str, List[str]] = {
+     "py": ["class ", "def "],
+     "md": ["## "],
+     "ts": [
+         "export function ",
+         "export class ",
+         "export const ",
+         "export interface ",
+         "export type ",
+         "export default ",
+         "function ",
+         "class ",
+     ],
+     "tsx": [
+         "export function ",
+         "export class ",
+         "export const ",
+         "export interface ",
+         "export type ",
+         "export default ",
+         "function ",
+         "class ",
+     ],
+     "js": ["export function ", "export class ", "export const ", "export default ", "function ", "class "],
+     "jsx": ["export function ", "export class ", "export const ", "export default ", "function ", "class "],
+ }
+
+
+ def _count_tokens(text: str) -> int:
+     return len(_encoder.encode(text))
+
+
+ def _split_by_fixed_length(
+     content: str,
+     max_tokens: int = 512,
+     overlap_lines: int = 5,
+     line_offset: int = 0,
+ ) -> List[FileSegment]:
+     lines = content.splitlines()
+     chunks: List[FileSegment] = []
+     chunk: List[str] = []
+     chunk_tokens = 0
+     chunk_start = 0
+
+     for i, line in enumerate(lines):
+         line_tokens = _count_tokens(line)
+         if chunk_tokens + line_tokens > max_tokens and chunk:
+             chunks.append(
+                 FileSegment(
+                     line_start=line_offset + chunk_start,
+                     line_end=line_offset + i,
+                     content="\n".join(chunk),
+                 )
+             )
+             overlap = chunk[-overlap_lines:]
+             chunk = overlap
+             chunk_tokens = sum(_count_tokens(ln) for ln in chunk)
+             chunk_start = i - len(overlap)
+         chunk.append(line)
+         chunk_tokens += line_tokens
+
+     if chunk:
+         chunks.append(
+             FileSegment(
+                 line_start=line_offset + chunk_start,
+                 line_end=line_offset + len(lines),
+                 content="\n".join(chunk),
+             )
+         )
+
+     return chunks
+
+
+ def _detect_splits(content: str, split_keywords: List[str]) -> List[int]:
+     split_lines = []
+     for i, line in enumerate(content.splitlines()):
+         for keyword in split_keywords:
+             if line.startswith(keyword):
+                 split_lines.append(i)
+                 break
+     if not split_lines or split_lines[0] != 0:
+         split_lines.insert(0, 0)
+     return split_lines
+
+
+ def _apply_splits(
+     content: str,
+     split_lines: List[int],
+     max_tokens: int = 512,
+     overlap_lines: int = 5,
+ ) -> List[FileSegment]:
+     segments = []
+     lines = content.splitlines()
+
+     for i in range(len(split_lines)):
+         split_start = split_lines[i]
+         split_end = split_lines[i + 1] if i + 1 < len(split_lines) else len(lines)
+
+         # Scan backwards to the nearest empty line so decorators/comments are included
+         actual_start = split_start
+         for j in range(split_start - 1, -1, -1):
+             if not lines[j].strip():
+                 actual_start = j + 1
+                 break
+         else:
+             if split_start > 0:
+                 actual_start = 0
+
+         if actual_start == split_end:
+             continue
+
+         segment_content = "\n".join(lines[actual_start:split_end])
+
+         if _count_tokens(segment_content) <= max_tokens:
+             segments.append(
+                 FileSegment(
+                     line_start=actual_start,
+                     line_end=split_end,
+                     content=segment_content,
+                 )
+             )
+         else:
+             segments.extend(
+                 _split_by_fixed_length(
+                     segment_content,
+                     max_tokens=max_tokens,
+                     overlap_lines=overlap_lines,
+                     line_offset=actual_start,
+                 )
+             )
+
+     return segments
+
+
+ class FileSplitter:
+     def __init__(self, max_tokens: int = 512, overlap_lines: int = 5):
+         self._max_tokens = max_tokens
+         self._overlap_lines = overlap_lines
+
+     def split_file(self, file_content: str, file_path: str) -> List[FileSegment]:
+
+         file_extension = file_path.split(".")[-1].lower()
+
+         if file_extension not in _SPLIT_KEYWORDS:
+             return _split_by_fixed_length(
+                 file_content,
+                 max_tokens=self._max_tokens,
+                 overlap_lines=self._overlap_lines,
+             )
+
+         split_lines = _detect_splits(file_content, _SPLIT_KEYWORDS[file_extension])
+         return _apply_splits(
+             file_content,
+             split_lines,
+             max_tokens=self._max_tokens,
+             overlap_lines=self._overlap_lines,
+         )
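
A self-contained sketch of the splitter on a toy Python source; for a .py path, the top-level def and class lines start new keyword-based segments:

    from codeembed.doc_splitters.generic_splitter import FileSplitter

    source = (
        "import os\n"
        "\n"
        "def hello(name):\n"
        "    return f\"Hello, {name}!\"\n"
        "\n"
        "class Greeter:\n"
        "    pass\n"
    )
    for seg in FileSplitter(max_tokens=64).split_file(source, "example.py"):
        print(seg.line_start, seg.line_end, seg.content.split("\n")[0])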
codeembed/doc_splitters/models.py ADDED
@@ -0,0 +1,14 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class FileSegment:
+     line_start: int
+     line_end: int
+     content: str
+
+
+ @dataclass
+ class SplittedFile:
+     file_path: str
+     full_content: str
codeembed/llm/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from codeembed.llm.base import LLMServiceBase
+ from codeembed.llm.models import ChatMessage, LLMResponse, StructuredLLMResponse
+ from codeembed.llm.ollama_adapter import OllamaLLMService
+ from codeembed.llm.openai_adapter import OpenAILLMService
+
+ __all__ = [
+     "ChatMessage",
+     "LLMResponse",
+     "LLMServiceBase",
+     "OllamaLLMService",
+     "OpenAILLMService",
+     "StructuredLLMResponse",
+ ]
codeembed/llm/base.py ADDED
@@ -0,0 +1,31 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Type, TypeVar
+
+ from pydantic import BaseModel
+
+ from codeembed.llm.models import ChatMessage, LLMResponse, StructuredLLMResponse
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class LLMServiceBase(ABC):
+     @abstractmethod
+     def generate_structured_output(
+         self,
+         messages: List[ChatMessage],
+         llm_model: str,
+         output_format: Type[T],
+         max_tokens: Optional[int] = None,
+         temperature: Optional[float] = None,
+     ) -> StructuredLLMResponse[T]:
+         pass
+
+     @abstractmethod
+     def generate_response(
+         self,
+         messages: List[ChatMessage],
+         llm_model: str,
+         max_tokens: Optional[int] = None,
+         temperature: Optional[float] = None,
+     ) -> LLMResponse:
+         pass
codeembed/llm/models.py ADDED
@@ -0,0 +1,27 @@
+ from dataclasses import dataclass
+ from typing import Generic, Literal, TypedDict, TypeVar
+
+ from pydantic import BaseModel
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class ChatMessage(TypedDict):
+     role: Literal["system", "user", "assistant"]
+     content: str
+
+
+ @dataclass
+ class LLMResponse:
+     input_tokens: int
+     output_tokens: int
+     response: str
+     llm_model: str
+
+
+ @dataclass
+ class StructuredLLMResponse(Generic[T]):
+     input_tokens: int
+     output_tokens: int
+     data: T
+     llm_model: str
codeembed/llm/ollama_adapter.py ADDED
@@ -0,0 +1,64 @@
+ from typing import List, Optional, Type, TypeVar
+
+ import ollama
+ from pydantic import BaseModel
+
+ from codeembed.llm.base import LLMServiceBase
+ from codeembed.llm.models import ChatMessage, LLMResponse, StructuredLLMResponse
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class OllamaLLMService(LLMServiceBase):
+     def generate_structured_output(
+         self,
+         messages: List[ChatMessage],
+         llm_model: str,
+         output_format: Type[T],
+         max_tokens: Optional[int] = None,
+         temperature: Optional[float] = None,
+     ) -> StructuredLLMResponse[T]:
+
+         options = {}
+         if max_tokens is not None:
+             options["num_predict"] = max_tokens
+         if temperature is not None:
+             options["temperature"] = temperature
+
+         resp = ollama.chat(model=llm_model, messages=messages, format="json", options=options)
+
+         data = resp["message"]["content"]
+
+         model = output_format.model_validate_json(data)
+
+         return StructuredLLMResponse(
+             input_tokens=resp["prompt_eval_count"] or 0,
+             output_tokens=resp["eval_count"] or 0,
+             data=model,
+             llm_model=llm_model,
+         )
+
+     def generate_response(
+         self,
+         messages: List[ChatMessage],
+         llm_model: str,
+         max_tokens: Optional[int] = None,
+         temperature: Optional[float] = None,
+     ) -> LLMResponse:
+
+         options = {}
+         if max_tokens is not None:
+             options["num_predict"] = max_tokens
+         if temperature is not None:
+             options["temperature"] = temperature
+
+         resp = ollama.chat(model=llm_model, messages=messages, options=options)
+
+         content = resp["message"]["content"]
+
+         return LLMResponse(
+             input_tokens=resp["prompt_eval_count"] or 0,
+             output_tokens=resp["eval_count"] or 0,
+             response=content,
+             llm_model=llm_model,
+         )
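
A usage sketch for the structured-output path, assuming a local Ollama server with the named model already pulled (the model name is illustrative). Note the adapter passes format="json" and then validates with Pydantic, so it relies on the model actually emitting JSON that matches the schema:

    from pydantic import BaseModel

    from codeembed.llm.ollama_adapter import OllamaLLMService

    class FileSummary(BaseModel):
        title: str
        summary: str

    service = OllamaLLMService()
    result = service.generate_structured_output(
        messages=[{"role": "user", "content": "Summarize: def add(a, b): return a + b. "
                                              "Reply as JSON with keys title and summary."}],
        llm_model="llama3.1",  # any locally pulled model
        output_format=FileSummary,
        temperature=0.0,
    )
    print(result.data.title, "|", result.output_tokens, "output tokens")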