codeembed 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeembed/__init__.py +59 -0
- codeembed/bootstrap/__init__.py +17 -0
- codeembed/bootstrap/services.py +220 -0
- codeembed/cli.py +454 -0
- codeembed/config/__init__.py +5 -0
- codeembed/config/models.py +13 -0
- codeembed/cost_tracking/__init__.py +7 -0
- codeembed/cost_tracking/llm_wrapper.py +39 -0
- codeembed/cost_tracking/models.py +52 -0
- codeembed/delta_computer/__init__.py +5 -0
- codeembed/delta_computer/delta_computer.py +75 -0
- codeembed/doc_embedder/__init__.py +5 -0
- codeembed/doc_embedder/doc_embedder.py +134 -0
- codeembed/doc_provider/__init__.py +10 -0
- codeembed/doc_provider/base.py +14 -0
- codeembed/doc_provider/local_doc_provider.py +58 -0
- codeembed/doc_provider/models.py +20 -0
- codeembed/doc_search_service/__init__.py +5 -0
- codeembed/doc_search_service/doc_search_service.py +48 -0
- codeembed/doc_splitters/__init__.py +8 -0
- codeembed/doc_splitters/generic_splitter.py +165 -0
- codeembed/doc_splitters/models.py +14 -0
- codeembed/llm/__init__.py +13 -0
- codeembed/llm/base.py +31 -0
- codeembed/llm/models.py +27 -0
- codeembed/llm/ollama_adapter.py +64 -0
- codeembed/llm/openai_adapter.py +96 -0
- codeembed/mcp_server.py +45 -0
- codeembed/setup_logger.py +34 -0
- codeembed/utils/__init__.py +9 -0
- codeembed/utils/checksum_utils.py +5 -0
- codeembed/utils/string_utils.py +5 -0
- codeembed/utils/time_utils.py +5 -0
- codeembed/vector_db/__init__.py +9 -0
- codeembed/vector_db/base.py +27 -0
- codeembed/vector_db/chromadb_adapter.py +130 -0
- codeembed/vector_db/models.py +16 -0
- codeembed-0.1.0.dist-info/METADATA +292 -0
- codeembed-0.1.0.dist-info/RECORD +42 -0
- codeembed-0.1.0.dist-info/WHEEL +4 -0
- codeembed-0.1.0.dist-info/entry_points.txt +2 -0
- codeembed-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
from codeembed.delta_computer.delta_computer import DeltaComputer
|
|
6
|
+
from codeembed.doc_provider.base import DocProviderBase
|
|
7
|
+
from codeembed.doc_splitters.generic_splitter import FileSplitter
|
|
8
|
+
from codeembed.doc_splitters.models import FileSegment
|
|
9
|
+
from codeembed.llm.base import LLMServiceBase
|
|
10
|
+
from codeembed.llm.models import ChatMessage
|
|
11
|
+
from codeembed.vector_db.base import VectorDbBase
|
|
12
|
+
from codeembed.vector_db.models import Chunk
|
|
13
|
+
|
|
14
|
+
# Module-level logger; segment summarization and embedding progress are reported through it.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _segment_to_chunk(
    llm_service: LLMServiceBase,
    segment: FileSegment,
    full_content: str,
    file_path: str,
    llm_model: str,
) -> str:
    """Ask the LLM for a short natural-language summary of one file segment.

    The complete file text is included in the prompt so the model can describe
    the segment in context. Despite the name, only the summary text is
    returned; the caller builds the chunk from it.

    Args:
        llm_service: Service used to generate the summary.
        segment: The segment to describe.
        full_content: Complete text of the file the segment belongs to.
        file_path: Path of that file, used in the prompt and in log messages.
        llm_model: Model name passed to ``llm_service``.

    Returns:
        The model's response text.
    """
    # NOTE: For markdown files we could embed directly without LLM summarization.
    # Just split on ## headers.

    logger.info("Analyzing segment %s in file %s...", segment.content.split("\n")[0], file_path)

    messages: List[ChatMessage] = [
        {"role": "system", "content": "You are an expert at describing code."},
        {
            "role": "user",
            "content": f"""In the context of the following file:
<File Path>{file_path}</File Path>
<FileContent>
{full_content}
</FileContent>
Please describe the purpose of the following code/text segment:
<Segment>
<Line Start>{segment.line_start}</Line Start>
<Content>
{segment.content}
</Content>
<Line End>{segment.line_end}</Line End>
</Segment>
If this is a function or class, please describe what it does and how it interacts with the application.
If it is a text paragraph, explain what it covers.
Focus on the key aspects of the text or code.
Write a succinct summary.
Return the summary only without any additional comments.
Start with, e.g.,
This <segment type> is ...
""",
        },
    ]

    result = llm_service.generate_response(messages, llm_model, max_tokens=1024, temperature=0.3)

    logger.info("Generated summary for segment in file %s: %s", file_path, result.response)

    return result.response
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class DocEmbedder:
    """Keeps a vector database in sync with the documents of a provider.

    Files that need updating are split into segments, each segment is
    summarized by an LLM, and the summary is stored as a chunk alongside the
    raw segment text.
    """

    def __init__(
        self,
        doc_provider: DocProviderBase,
        vector_db: VectorDbBase,
        llm_service: LLMServiceBase,
        llm_model: str,
        debounce_seconds: int = 10,
    ) -> None:
        """Create an embedder.

        Args:
            doc_provider: Source of the documents to embed.
            vector_db: Destination for the generated chunks.
            llm_service: Service used to summarize segments.
            llm_model: Model name passed to ``llm_service``.
            debounce_seconds: Forwarded to ``DeltaComputer``; its exact meaning
                is defined there (presumably a minimum age before a changed
                file is re-embedded — TODO confirm).
        """
        self._doc_provider = doc_provider
        self._vector_db = vector_db
        self._llm_service = llm_service
        self._llm_model = llm_model
        self._debounce_seconds = debounce_seconds

    def embed_codebase(self) -> None:
        """Embeds the codebase and prepares it for vector search.

        Deletes the chunks ``DeltaComputer`` reports as stale, then splits,
        summarizes and stores every file it reports as needing an update.
        Files that cannot be read or decoded are logged and counted as
        skipped instead of aborting the whole run.
        """
        logger.info("Computing deltas...")

        chunks_ids_to_remove, files_to_update = DeltaComputer(
            self._doc_provider, self._vector_db, self._debounce_seconds
        ).compute_deltas()

        num_files = len(files_to_update)
        logger.info("Detected %d chunks to delete from vector database.", len(chunks_ids_to_remove))
        logger.info("Detected %d files to process.", num_files)

        if chunks_ids_to_remove:
            logger.info("Deleting %d chunks from vector database.", len(chunks_ids_to_remove))
            self._vector_db.delete_chunks(list(chunks_ids_to_remove))

        logger.info("Processing %d files...", num_files)

        num_processed = 0
        num_skipped = 0

        splitter = FileSplitter()

        for i, file in enumerate(files_to_update, start=1):
            logger.info("Processing file '%s' (%d/%d)...", file, i, num_files)
            try:
                doc = self._doc_provider.get_content(file)
            except (OSError, UnicodeDecodeError) as err:
                # e.g. the file disappeared after the deltas were computed, or it is
                # not valid in the encoding the provider reads it with.
                logger.warning("Could not read file '%s' (%s). Skipping embedding for this file.", file, err)
                num_skipped += 1
                continue

            # The checksum hashes the whole file content; compute it once per file
            # instead of once per segment.
            checksum = doc.sha256_checksum

            chunks = []
            for segment in splitter.split_file(doc.content, file):
                summary = _segment_to_chunk(self._llm_service, segment, doc.content, file, self._llm_model)
                chunks.append(
                    Chunk(
                        id=uuid4(),
                        modified_at=doc.modified_at,
                        content=summary,
                        file_path=file,
                        line_start=segment.line_start,
                        line_end=segment.line_end,
                        raw_code=segment.content,
                        file_sha256_checksum=checksum,
                    )
                )
            if not chunks:
                logger.warning("No chunks generated for file '%s'. Skipping embedding for this file.", file)
                num_skipped += 1
                continue
            logger.info("Saving %d chunks to vector database.", len(chunks))
            self._vector_db.add_chunks(chunks)
            num_processed += 1
            logger.info("Successfully embedded file: '%s' (%d/%d).", file, i, num_files)

        if num_processed > 0:
            logger.info("Successfully embedded %d files.", num_processed)
        if num_skipped > 0:
            logger.warning("Skipped processing %d files.", num_skipped)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from codeembed.doc_provider.base import DocProviderBase
|
|
2
|
+
from codeembed.doc_provider.local_doc_provider import LocalDocProvider
|
|
3
|
+
from codeembed.doc_provider.models import DocumentContent, DocumentMeta
|
|
4
|
+
|
|
5
|
+
# Names re-exported as the public API of the doc_provider package.
__all__ = [
    "DocProviderBase",
    "DocumentContent",
    "DocumentMeta",
    "LocalDocProvider",
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Iterator
|
|
3
|
+
|
|
4
|
+
from codeembed.doc_provider.models import DocumentContent, DocumentMeta
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DocProviderBase(ABC):
    """Abstract source of documents for the embedding pipeline.

    Implementations list the available documents through :meth:`iter` and
    load one document's text through :meth:`get_content`.
    """

    @abstractmethod
    def iter(self) -> Iterator[DocumentMeta]:
        """Yield a :class:`DocumentMeta` (path and modification time) for each available document."""

    @abstractmethod
    def get_content(self, file_path: str) -> DocumentContent:
        """Load and return the content of the document at *file_path*."""
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import subprocess
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import Iterator, List
|
|
5
|
+
|
|
6
|
+
from codeembed.doc_provider.base import DocProviderBase
|
|
7
|
+
from codeembed.doc_provider.models import DocumentContent, DocumentMeta
|
|
8
|
+
|
|
9
|
+
# Directory names to ignore: a file is skipped when any of its parent directories has one of these names.
_SKIP_DIRS = frozenset({"venv", ".venv", "node_modules", "dist", "build"})
# File names (basename only) that are never indexed. The env/appsettings entries presumably keep
# secrets and local configuration out of the index — TODO confirm intent.
_SKIP_FILES = frozenset({"__init__.py", ".env", ".env.local", "appsettings.json", "appsettings.Development.json"})
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _get_git_files(base_path: str) -> set[str]:
|
|
14
|
+
result = subprocess.run(
|
|
15
|
+
["git", "ls-files", "--cached", "--others", "--exclude-standard"],
|
|
16
|
+
cwd=base_path,
|
|
17
|
+
capture_output=True,
|
|
18
|
+
text=True,
|
|
19
|
+
)
|
|
20
|
+
return set(result.stdout.splitlines())
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LocalDocProvider(DocProviderBase):
    """Serves files from a local git working tree rooted at ``base_path``.

    Files are discovered with ``git ls-files`` and filtered by extension and
    by the module-level skip lists. Paths yielded by :meth:`iter` are relative
    to ``base_path``. All filesystem access resolves paths against
    ``base_path``, so the provider works regardless of the process's current
    working directory.
    """

    def __init__(self, base_path: str, supported_file_extensions: List[str]) -> None:
        """Create a provider.

        Args:
            base_path: Root directory of the git working tree to index.
            supported_file_extensions: Extensions to include, case-insensitive.
                Only the part after the last dot is kept, so ``".py"``,
                ``"py"`` and ``"*.PY"`` are equivalent.
        """
        self._base_path = base_path
        self._supported_file_extensions = [ext.lower().split(".")[-1] for ext in supported_file_extensions]

    def _resolve(self, file_path: str) -> str:
        """Return *file_path* joined onto the base path (absolute paths pass through unchanged)."""
        return os.path.join(self._base_path, file_path)

    def iter(self) -> Iterator[DocumentMeta]:
        """Yield metadata for every indexable file, most recently modified first."""
        docs: List[DocumentMeta] = []
        for file_path in _get_git_files(self._base_path):
            # git always uses "/" as the separator in its output.
            parts = file_path.split("/")
            # Take the extension from the file name only, so a dot in a directory name is ignored.
            ext = parts[-1].split(".")[-1]
            if ext.lower() not in self._supported_file_extensions:
                continue

            if parts[-1] in _SKIP_FILES or any(d in _SKIP_DIRS for d in parts[:-1]):
                continue

            try:
                # Resolve against base_path: git reports paths relative to it, not to the CWD.
                modified_ts = os.path.getmtime(self._resolve(file_path))
                modified_at = datetime.fromtimestamp(modified_ts, tz=timezone.utc)
            except OSError:
                # Listed by git but not present on disk (e.g. a tracked file deleted from the working tree).
                continue

            docs.append(DocumentMeta(file_path=file_path, modified_at=modified_at))

        docs.sort(key=lambda d: d.modified_at, reverse=True)
        yield from docs

    def get_content(self, file_path: str) -> DocumentContent:
        """Read *file_path* (relative to the base path) as UTF-8 text.

        Raises:
            OSError: If the file cannot be opened or stat'ed.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        full_path = self._resolve(file_path)
        with open(full_path, "r", encoding="utf-8") as f:
            content = f.read()
        modified_ts = os.path.getmtime(full_path)
        modified_at = datetime.fromtimestamp(modified_ts, tz=timezone.utc)
        return DocumentContent(content=content, modified_at=modified_at)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
|
|
4
|
+
from codeembed.utils.checksum_utils import string_to_sha256
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class DocumentMeta:
    """Lightweight description of a document: where it lives and when it last changed."""

    # Path of the document as reported by its provider.
    file_path: str
    # Last modification time of the document.
    modified_at: datetime
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class DocumentContent:
    """Full text of a document together with its modification time."""

    # The document text.
    content: str
    # Last modification time of the document.
    modified_at: datetime

    @property
    def sha256_checksum(self) -> str:
        """Checksum of ``content`` as produced by ``string_to_sha256``; recomputed on every access."""
        return string_to_sha256(self.content)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Dict, List
|
|
2
|
+
|
|
3
|
+
from codeembed.utils.string_utils import truncate_string
|
|
4
|
+
from codeembed.vector_db.base import VectorDbBase
|
|
5
|
+
from codeembed.vector_db.models import Chunk
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DocSearchService:
    """Vector-search front end that renders hits as an XML-like document for an LLM."""

    def __init__(
        self,
        vector_db: VectorDbBase,
    ) -> None:
        """Args:
        vector_db: Database queried by :meth:`search`.
        """
        self._vector_db = vector_db

    def search(self, query: str, top_n: int = 10) -> str:
        """Return the *top_n* chunks most relevant to *query*, formatted for an LLM.

        Hits are grouped per file, keeping files in the order in which their
        first hit was returned. Each chunk's summary and raw code are
        shortened with ``truncate_string(..., 4096)``.
        """
        hits = self._vector_db.search(query, top_n)

        by_file: Dict[str, List[Chunk]] = {}
        for hit in hits:
            by_file.setdefault(hit.file_path, []).append(hit)

        out: List[str] = [
            f"<SearchQuery>{query}</SearchQuery>\n",
            f"<TopN>{top_n}</TopN>\n",
            f"<Results chunkCount={len(hits)} fileCount={len(by_file)}>\n",
        ]
        for path, file_hits in by_file.items():
            out.append(f' <File path="{path}">\n')
            for hit in file_hits:
                # NOTE: Consider truncating by number of tokens.
                code = hit.raw_code or ""
                summary = truncate_string(hit.content, 4096)
                snippet = truncate_string(code, 4096)
                out.append(" <Chunk>\n")
                out.append(f" <Summary>\n{summary}\n </Summary>\n")
                out.append(f' <RawCode lines="{hit.line_start}-{hit.line_end}">\n{snippet}\n </RawCode>\n')
                out.append(" </Chunk>\n")
            out.append(" </File>\n")
        out.append("</Results>\n")
        return "".join(out)
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from typing import Dict, List
|
|
2
|
+
|
|
3
|
+
import tiktoken
|
|
4
|
+
|
|
5
|
+
from codeembed.doc_splitters.models import FileSegment
|
|
6
|
+
|
|
7
|
+
# Shared tiktoken encoder, created once at import time; used by the token counting
# that decides how large each segment may grow.
_encoder = tiktoken.get_encoding("o200k_base")
|
|
8
|
+
|
|
9
|
+
_SPLIT_KEYWORDS: Dict[str, List[str]] = {
|
|
10
|
+
"py": ["class ", "def "],
|
|
11
|
+
"md": ["## "],
|
|
12
|
+
"ts": [
|
|
13
|
+
"export function ",
|
|
14
|
+
"export class ",
|
|
15
|
+
"export const ",
|
|
16
|
+
"export interface ",
|
|
17
|
+
"export type ",
|
|
18
|
+
"export default ",
|
|
19
|
+
"function ",
|
|
20
|
+
"class ",
|
|
21
|
+
],
|
|
22
|
+
"tsx": [
|
|
23
|
+
"export function ",
|
|
24
|
+
"export class ",
|
|
25
|
+
"export const ",
|
|
26
|
+
"export interface ",
|
|
27
|
+
"export type ",
|
|
28
|
+
"export default ",
|
|
29
|
+
"function ",
|
|
30
|
+
"class ",
|
|
31
|
+
],
|
|
32
|
+
"js": ["export function ", "export class ", "export const ", "export default ", "function ", "class "],
|
|
33
|
+
"jsx": ["export function ", "export class ", "export const ", "export default ", "function ", "class "],
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _count_tokens(text: str) -> int:
    """Return the number of ``o200k_base`` tokens in *text*.

    ``encode_ordinary`` treats special-token markers such as ``<|endoftext|>``
    as plain text. ``encode`` raises ``ValueError`` on them by default, which
    would abort splitting any source file that happens to contain such a string.
    """
    return len(_encoder.encode_ordinary(text))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _split_by_fixed_length(
    content: str,
    max_tokens: int = 512,
    overlap_lines: int = 5,
    line_offset: int = 0,
) -> List[FileSegment]:
    """Split *content* into line-aligned segments of roughly *max_tokens* tokens.

    Lines are accumulated until adding the next one would exceed *max_tokens*.
    The accumulated lines are then emitted as a segment, and up to
    *overlap_lines* trailing lines are carried into the next segment for
    context. Lines are never split, so a single line longer than the budget
    still produces an over-budget segment.

    Args:
        content: Text to split.
        max_tokens: Token budget per segment.
        overlap_lines: Trailing lines repeated at the start of the next
            segment; values <= 0 disable overlap.
        line_offset: Added to every line number, for when *content* is a
            slice of a larger file.

    Returns:
        Segments with 0-based ``line_start`` and exclusive ``line_end``.
    """
    lines = content.splitlines()
    segments: List[FileSegment] = []
    chunk: List[str] = []
    counts: List[int] = []  # token count of each line in `chunk`
    chunk_tokens = 0
    chunk_start = 0

    for i, line in enumerate(lines):
        line_tokens = _count_tokens(line)
        if chunk and chunk_tokens + line_tokens > max_tokens:
            segments.append(
                FileSegment(
                    line_start=line_offset + chunk_start,
                    line_end=line_offset + i,
                    content="\n".join(chunk),
                )
            )
            # Carry at most overlap_lines lines, and always fewer than the emitted
            # chunk so the next segment starts later. (A plain chunk[-overlap_lines:]
            # would keep the WHOLE chunk when overlap_lines == 0.)
            keep = max(0, min(overlap_lines, len(chunk) - 1))
            cut = len(chunk) - keep
            chunk, counts = chunk[cut:], counts[cut:]
            # Overlap is only context: drop the oldest carried lines if they would
            # leave no room for the current line.
            while chunk and sum(counts) + line_tokens > max_tokens:
                del chunk[0]
                del counts[0]
            chunk_tokens = sum(counts)
            chunk_start = i - len(chunk)
        chunk.append(line)
        counts.append(line_tokens)
        chunk_tokens += line_tokens

    if chunk:
        segments.append(
            FileSegment(
                line_start=line_offset + chunk_start,
                line_end=line_offset + len(lines),
                content="\n".join(chunk),
            )
        )

    return segments
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _detect_splits(content: str, split_keywords: List[str]) -> List[int]:
|
|
83
|
+
split_lines = []
|
|
84
|
+
for i, line in enumerate(content.splitlines()):
|
|
85
|
+
for keyword in split_keywords:
|
|
86
|
+
if line.startswith(keyword):
|
|
87
|
+
split_lines.append(i)
|
|
88
|
+
break
|
|
89
|
+
if not split_lines or split_lines[0] != 0:
|
|
90
|
+
split_lines.insert(0, 0)
|
|
91
|
+
return split_lines
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _apply_splits(
    content: str,
    split_lines: List[int],
    max_tokens: int = 512,
    overlap_lines: int = 5,
) -> List[FileSegment]:
    """Cut *content* into segments at the given split lines.

    Each split line is moved up over the directly preceding non-blank lines
    (decorators, comments, doc blocks) so they stay with the definition they
    annotate. This only happens when a blank line separates them from the
    previous split line; otherwise the split stays where it is instead of
    swallowing earlier definitions. Segments are contiguous and
    non-overlapping: each one ends where the next one starts. Segments over
    *max_tokens* are further split with ``_split_by_fixed_length``.

    Args:
        content: Full file text.
        split_lines: Ascending 0-based line indices where segments begin, as
            produced by ``_detect_splits``.
        max_tokens: Token budget per segment.
        overlap_lines: Overlap used when an oversized segment is split further.

    Returns:
        Segments with 0-based ``line_start`` and exclusive ``line_end``.
    """
    lines = content.splitlines()

    # Pass 1: decide where each segment really starts.
    starts: List[int] = []
    for idx, split_start in enumerate(split_lines):
        start = split_start
        # Scan backwards to the nearest empty line so decorators/comments are included.
        while start > 0 and lines[start - 1].strip():
            start -= 1
        if idx > 0 and start <= split_lines[idx - 1]:
            # No blank line separates this definition from the previous one; moving
            # the start up would duplicate the previous definition in this segment.
            start = split_start
        starts.append(start)

    # Pass 2: each segment runs up to the (adjusted) start of the next one.
    segments: List[FileSegment] = []
    for idx, start in enumerate(starts):
        end = starts[idx + 1] if idx + 1 < len(starts) else len(lines)
        if start >= end:
            continue

        segment_content = "\n".join(lines[start:end])

        if _count_tokens(segment_content) <= max_tokens:
            segments.append(
                FileSegment(
                    line_start=start,
                    line_end=end,
                    content=segment_content,
                )
            )
        else:
            segments.extend(
                _split_by_fixed_length(
                    segment_content,
                    max_tokens=max_tokens,
                    overlap_lines=overlap_lines,
                    line_offset=start,
                )
            )

    return segments
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class FileSplitter:
    """Splits a file's text into :class:`FileSegment` objects sized for summarization.

    Files whose lowercase extension has an entry in ``_SPLIT_KEYWORDS`` are cut
    at definition boundaries first; all other files are chunked by token count.
    """

    def __init__(self, max_tokens: int = 512, overlap_lines: int = 5):
        """Args:
        max_tokens: Token budget per segment.
        overlap_lines: Lines repeated between consecutive fixed-length segments.
        """
        self._max_tokens = max_tokens
        self._overlap_lines = overlap_lines

    def split_file(self, file_content: str, file_path: str) -> List[FileSegment]:
        """Split *file_content*, choosing the strategy from *file_path*'s extension."""
        extension = file_path.split(".")[-1].lower()
        keywords = _SPLIT_KEYWORDS.get(extension)

        if keywords is None:
            # No known definition keywords for this file type: plain token-based chunking.
            return _split_by_fixed_length(
                file_content,
                max_tokens=self._max_tokens,
                overlap_lines=self._overlap_lines,
            )

        return _apply_splits(
            file_content,
            _detect_splits(file_content, keywords),
            max_tokens=self._max_tokens,
            overlap_lines=self._overlap_lines,
        )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from codeembed.llm.base import LLMServiceBase
|
|
2
|
+
from codeembed.llm.models import ChatMessage, LLMResponse, StructuredLLMResponse
|
|
3
|
+
from codeembed.llm.ollama_adapter import OllamaLLMService
|
|
4
|
+
from codeembed.llm.openai_adapter import OpenAILLMService
|
|
5
|
+
|
|
6
|
+
# Names re-exported as the public API of the llm package.
__all__ = [
    "ChatMessage",
    "LLMResponse",
    "LLMServiceBase",
    "OllamaLLMService",
    "OpenAILLMService",
    "StructuredLLMResponse",
]
|
codeembed/llm/base.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Optional, Type, TypeVar
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from codeembed.llm.models import ChatMessage, LLMResponse, StructuredLLMResponse
|
|
7
|
+
|
|
8
|
+
# Pydantic model type requested through `output_format` in generate_structured_output.
T = TypeVar("T", bound=BaseModel)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LLMServiceBase(ABC):
    """Provider-agnostic interface for chat-style LLM calls."""

    @abstractmethod
    def generate_structured_output(
        self,
        messages: List[ChatMessage],
        llm_model: str,
        output_format: Type[T],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> StructuredLLMResponse[T]:
        """Run a chat completion whose reply is parsed into an instance of *output_format*.

        Args:
            messages: Conversation to send.
            llm_model: Backend-specific model name.
            output_format: Pydantic model class the reply is parsed into.
            max_tokens: Optional cap on generated tokens; ``None`` to omit.
            temperature: Optional sampling temperature; ``None`` to omit.
        """

    @abstractmethod
    def generate_response(
        self,
        messages: List[ChatMessage],
        llm_model: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> LLMResponse:
        """Run a chat completion and return the reply as plain text, with token usage.

        Args:
            messages: Conversation to send.
            llm_model: Backend-specific model name.
            max_tokens: Optional cap on generated tokens; ``None`` to omit.
            temperature: Optional sampling temperature; ``None`` to omit.
        """
|
codeembed/llm/models.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Generic, Literal, TypedDict, TypeVar
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
# Pydantic model type carried in StructuredLLMResponse.data.
T = TypeVar("T", bound=BaseModel)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ChatMessage(TypedDict):
    """One chat turn in the ``{"role": ..., "content": ...}`` shape passed to the LLM services."""

    # Author of the message.
    role: Literal["system", "user", "assistant"]
    # Message text.
    content: str
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class LLMResponse:
    """Plain-text completion returned by an LLM service, with token usage."""

    # Prompt tokens reported by the backend.
    input_tokens: int
    # Completion tokens reported by the backend.
    output_tokens: int
    # Generated text.
    response: str
    # Model that produced the response.
    llm_model: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class StructuredLLMResponse(Generic[T]):
    """Chat completion parsed into a pydantic model of type ``T``, with token usage."""

    # Prompt tokens reported by the backend.
    input_tokens: int
    # Completion tokens reported by the backend.
    output_tokens: int
    # The parsed reply.
    data: T
    # Model that produced the reply.
    llm_model: str
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import List, Optional, Type, TypeVar
|
|
2
|
+
|
|
3
|
+
import ollama
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from codeembed.llm.base import LLMServiceBase
|
|
7
|
+
from codeembed.llm.models import ChatMessage, LLMResponse, StructuredLLMResponse
|
|
8
|
+
|
|
9
|
+
# Pydantic model type requested through `output_format` in generate_structured_output.
T = TypeVar("T", bound=BaseModel)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OllamaLLMService(LLMServiceBase):
    """:class:`LLMServiceBase` implementation backed by the ``ollama`` client's ``chat`` function."""

    @staticmethod
    def _build_options(max_tokens: Optional[int], temperature: Optional[float]) -> dict:
        """Map the generic sampling arguments onto Ollama ``options``, leaving out unset ones."""
        opts = {}
        if max_tokens is not None:
            # Ollama's name for the cap on generated tokens.
            opts["num_predict"] = max_tokens
        if temperature is not None:
            opts["temperature"] = temperature
        return opts

    def generate_structured_output(
        self,
        messages: List[ChatMessage],
        llm_model: str,
        output_format: Type[T],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> StructuredLLMResponse[T]:
        """Chat with *llm_model* in JSON mode and validate the reply against *output_format*.

        Token counts that come back as ``None`` (or 0) are reported as 0.
        Validation errors from ``output_format.model_validate_json`` propagate.
        """
        reply = ollama.chat(
            model=llm_model,
            messages=messages,
            format="json",
            options=self._build_options(max_tokens, temperature),
        )
        parsed = output_format.model_validate_json(reply["message"]["content"])
        return StructuredLLMResponse(
            input_tokens=reply["prompt_eval_count"] or 0,
            output_tokens=reply["eval_count"] or 0,
            data=parsed,
            llm_model=llm_model,
        )

    def generate_response(
        self,
        messages: List[ChatMessage],
        llm_model: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> LLMResponse:
        """Chat with *llm_model* and return the reply text.

        Token counts that come back as ``None`` (or 0) are reported as 0.
        """
        reply = ollama.chat(
            model=llm_model,
            messages=messages,
            options=self._build_options(max_tokens, temperature),
        )
        text = reply["message"]["content"]
        return LLMResponse(
            input_tokens=reply["prompt_eval_count"] or 0,
            output_tokens=reply["eval_count"] or 0,
            response=text,
            llm_model=llm_model,
        )
|