amsdal_ml 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amsdal_ml/Third-Party Materials - AMSDAL Dependencies - License Notices.md +617 -0
- amsdal_ml/__about__.py +1 -1
- amsdal_ml/agents/__init__.py +13 -0
- amsdal_ml/agents/agent.py +5 -7
- amsdal_ml/agents/default_qa_agent.py +108 -143
- amsdal_ml/agents/functional_calling_agent.py +233 -0
- amsdal_ml/agents/mcp_client_tool.py +46 -0
- amsdal_ml/agents/python_tool.py +86 -0
- amsdal_ml/agents/retriever_tool.py +5 -6
- amsdal_ml/agents/tool_adapters.py +98 -0
- amsdal_ml/fileio/base_loader.py +7 -5
- amsdal_ml/fileio/openai_loader.py +16 -17
- amsdal_ml/mcp_client/base.py +2 -0
- amsdal_ml/mcp_client/http_client.py +7 -1
- amsdal_ml/mcp_client/stdio_client.py +19 -16
- amsdal_ml/mcp_server/server_retriever_stdio.py +8 -11
- amsdal_ml/ml_ingesting/__init__.py +29 -0
- amsdal_ml/ml_ingesting/default_ingesting.py +49 -51
- amsdal_ml/ml_ingesting/embedders/__init__.py +4 -0
- amsdal_ml/ml_ingesting/embedders/embedder.py +12 -0
- amsdal_ml/ml_ingesting/embedders/openai_embedder.py +30 -0
- amsdal_ml/ml_ingesting/embedding_data.py +3 -0
- amsdal_ml/ml_ingesting/loaders/__init__.py +6 -0
- amsdal_ml/ml_ingesting/loaders/folder_loader.py +52 -0
- amsdal_ml/ml_ingesting/loaders/loader.py +28 -0
- amsdal_ml/ml_ingesting/loaders/pdf_loader.py +136 -0
- amsdal_ml/ml_ingesting/loaders/text_loader.py +44 -0
- amsdal_ml/ml_ingesting/model_ingester.py +278 -0
- amsdal_ml/ml_ingesting/pipeline.py +131 -0
- amsdal_ml/ml_ingesting/pipeline_interface.py +31 -0
- amsdal_ml/ml_ingesting/processors/__init__.py +4 -0
- amsdal_ml/ml_ingesting/processors/cleaner.py +14 -0
- amsdal_ml/ml_ingesting/processors/text_cleaner.py +42 -0
- amsdal_ml/ml_ingesting/splitters/__init__.py +4 -0
- amsdal_ml/ml_ingesting/splitters/splitter.py +15 -0
- amsdal_ml/ml_ingesting/splitters/token_splitter.py +85 -0
- amsdal_ml/ml_ingesting/stores/__init__.py +4 -0
- amsdal_ml/ml_ingesting/stores/embedding_data.py +63 -0
- amsdal_ml/ml_ingesting/stores/store.py +22 -0
- amsdal_ml/ml_ingesting/types.py +40 -0
- amsdal_ml/ml_models/models.py +96 -4
- amsdal_ml/ml_models/openai_model.py +430 -122
- amsdal_ml/ml_models/utils.py +7 -0
- amsdal_ml/ml_retrievers/__init__.py +17 -0
- amsdal_ml/ml_retrievers/adapters.py +93 -0
- amsdal_ml/ml_retrievers/default_retriever.py +11 -1
- amsdal_ml/ml_retrievers/openai_retriever.py +27 -7
- amsdal_ml/ml_retrievers/query_retriever.py +487 -0
- amsdal_ml/ml_retrievers/retriever.py +12 -0
- amsdal_ml/models/embedding_model.py +7 -7
- amsdal_ml/prompts/__init__.py +77 -0
- amsdal_ml/prompts/database_query_agent.prompt +14 -0
- amsdal_ml/prompts/functional_calling_agent_base.prompt +9 -0
- amsdal_ml/prompts/nl_query_filter.prompt +318 -0
- amsdal_ml/{agents/promts → prompts}/react_chat.prompt +17 -8
- amsdal_ml/utils/__init__.py +5 -0
- amsdal_ml/utils/query_utils.py +189 -0
- {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.1.dist-info}/METADATA +61 -3
- amsdal_ml-0.2.1.dist-info/RECORD +72 -0
- {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.1.dist-info}/WHEEL +1 -1
- amsdal_ml/agents/promts/__init__.py +0 -58
- amsdal_ml-0.1.4.dist-info/RECORD +0 -39
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import base64
|
|
5
|
+
import logging
|
|
5
6
|
import os
|
|
6
7
|
from collections.abc import Iterable
|
|
7
8
|
from contextlib import AsyncExitStack
|
|
@@ -16,6 +17,8 @@ from mcp.shared.exceptions import McpError
|
|
|
16
17
|
from amsdal_ml.mcp_client.base import ToolClient
|
|
17
18
|
from amsdal_ml.mcp_client.base import ToolInfo
|
|
18
19
|
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
19
22
|
|
|
20
23
|
class StdioClient(ToolClient):
|
|
21
24
|
"""
|
|
@@ -23,22 +26,22 @@ class StdioClient(ToolClient):
|
|
|
23
26
|
"""
|
|
24
27
|
|
|
25
28
|
def __init__(
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
29
|
+
self,
|
|
30
|
+
alias: str,
|
|
31
|
+
module_or_cmd: str,
|
|
32
|
+
*args: str,
|
|
33
|
+
persist_session: bool = True,
|
|
34
|
+
send_amsdal_config: bool = True,
|
|
32
35
|
):
|
|
33
36
|
self.alias = alias
|
|
34
|
-
if module_or_cmd in (
|
|
37
|
+
if module_or_cmd in ('python', 'python3'):
|
|
35
38
|
self._command = module_or_cmd
|
|
36
39
|
self._args = list(args)
|
|
37
40
|
else:
|
|
38
|
-
self._command =
|
|
39
|
-
self._args = [
|
|
41
|
+
self._command = 'python'
|
|
42
|
+
self._args = ['-m', module_or_cmd]
|
|
40
43
|
if send_amsdal_config:
|
|
41
|
-
self._args.append(
|
|
44
|
+
self._args.append('--amsdal-config')
|
|
42
45
|
self._args.append(self._build_amsdal_config_arg())
|
|
43
46
|
self._persist = persist_session
|
|
44
47
|
self._lock = asyncio.Lock()
|
|
@@ -82,8 +85,8 @@ class StdioClient(ToolClient):
|
|
|
82
85
|
ToolInfo(
|
|
83
86
|
alias=alias,
|
|
84
87
|
name=t.name,
|
|
85
|
-
description=(getattr(t,
|
|
86
|
-
input_schema=(getattr(t,
|
|
88
|
+
description=(getattr(t, 'description', None) or ''),
|
|
89
|
+
input_schema=(getattr(t, 'inputSchema', None) or {}),
|
|
87
90
|
)
|
|
88
91
|
for t in resp_tools
|
|
89
92
|
]
|
|
@@ -127,20 +130,20 @@ class StdioClient(ToolClient):
|
|
|
127
130
|
rx, tx = await stack.enter_async_context(stdio_client(params))
|
|
128
131
|
s = await stack.enter_async_context(ClientSession(rx, tx))
|
|
129
132
|
await s.initialize()
|
|
130
|
-
|
|
133
|
+
logger.debug("Calling tool: %s with args: %s", tool_name, args)
|
|
131
134
|
res = await self._call_with_timeout(s.call_tool(tool_name, args), timeout=timeout)
|
|
132
|
-
return getattr(res,
|
|
135
|
+
return getattr(res, 'content', res)
|
|
133
136
|
|
|
134
137
|
# Persistent session path
|
|
135
138
|
s = await self._ensure_session()
|
|
136
139
|
try:
|
|
137
140
|
res = await self._call_with_timeout(s.call_tool(tool_name, args), timeout=timeout)
|
|
138
|
-
return getattr(res,
|
|
141
|
+
return getattr(res, 'content', res)
|
|
139
142
|
except (TimeoutError, McpError):
|
|
140
143
|
await self._reset_session()
|
|
141
144
|
s = await self._ensure_session()
|
|
142
145
|
res = await self._call_with_timeout(s.call_tool(tool_name, args), timeout=timeout)
|
|
143
|
-
return getattr(res,
|
|
146
|
+
return getattr(res, 'content', res)
|
|
144
147
|
|
|
145
148
|
def _build_amsdal_config_arg(self) -> str:
|
|
146
149
|
"""
|
|
@@ -18,32 +18,29 @@ from amsdal_ml.agents.retriever_tool import retriever_search
|
|
|
18
18
|
logging.basicConfig(
|
|
19
19
|
level=logging.INFO,
|
|
20
20
|
format='%(asctime)s [%(levelname)s] %(message)s',
|
|
21
|
-
handlers=[
|
|
22
|
-
logging.FileHandler("server.log"),
|
|
23
|
-
logging.StreamHandler(sys.stdout)
|
|
24
|
-
]
|
|
21
|
+
handlers=[logging.FileHandler('server.log'), logging.StreamHandler(sys.stdout)],
|
|
25
22
|
)
|
|
26
23
|
|
|
27
24
|
parser = argparse.ArgumentParser()
|
|
28
|
-
parser.add_argument(
|
|
25
|
+
parser.add_argument('--amsdal-config', required=False, help='Base64-encoded config string')
|
|
29
26
|
args = parser.parse_args()
|
|
30
27
|
|
|
31
|
-
logging.info(f
|
|
28
|
+
logging.info(f'Starting server with args: {args}')
|
|
32
29
|
|
|
33
30
|
if args.amsdal_config:
|
|
34
|
-
decoded = base64.b64decode(args.amsdal_config).decode(
|
|
31
|
+
decoded = base64.b64decode(args.amsdal_config).decode('utf-8')
|
|
35
32
|
amsdal_config = AmsdalConfig(**json.loads(decoded))
|
|
36
|
-
logging.info(f
|
|
33
|
+
logging.info(f'Loaded Amsdal config: {amsdal_config}')
|
|
37
34
|
AmsdalConfigManager().set_config(amsdal_config)
|
|
38
35
|
|
|
39
36
|
manager: Any
|
|
40
37
|
if amsdal_config.async_mode:
|
|
41
38
|
manager = AsyncAmsdalManager()
|
|
42
|
-
logging.info(
|
|
39
|
+
logging.info('pre-setup')
|
|
43
40
|
asyncio.run(cast(Any, manager).setup())
|
|
44
|
-
logging.info(
|
|
41
|
+
logging.info('post-setup')
|
|
45
42
|
asyncio.run(cast(Any, manager).post_setup())
|
|
46
|
-
logging.info(
|
|
43
|
+
logging.info('manager inited')
|
|
47
44
|
else:
|
|
48
45
|
manager = AmsdalManager()
|
|
49
46
|
cast(Any, manager).setup()
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from amsdal_ml.ml_ingesting.embedders.embedder import Embedder
|
|
2
|
+
from amsdal_ml.ml_ingesting.loaders.loader import Loader
|
|
3
|
+
from amsdal_ml.ml_ingesting.loaders.text_loader import TextLoader
|
|
4
|
+
from amsdal_ml.ml_ingesting.model_ingester import ModelIngester
|
|
5
|
+
from amsdal_ml.ml_ingesting.pipeline import DefaultIngestionPipeline
|
|
6
|
+
from amsdal_ml.ml_ingesting.pipeline_interface import IngestionPipeline
|
|
7
|
+
from amsdal_ml.ml_ingesting.processors.cleaner import Cleaner
|
|
8
|
+
from amsdal_ml.ml_ingesting.splitters.splitter import Splitter
|
|
9
|
+
from amsdal_ml.ml_ingesting.stores.store import EmbeddingStore
|
|
10
|
+
from amsdal_ml.ml_ingesting.types import IngestionSource
|
|
11
|
+
from amsdal_ml.ml_ingesting.types import LoadedDocument
|
|
12
|
+
from amsdal_ml.ml_ingesting.types import LoadedPage
|
|
13
|
+
from amsdal_ml.ml_ingesting.types import TextChunk
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
'Cleaner',
|
|
17
|
+
'DefaultIngestionPipeline',
|
|
18
|
+
'Embedder',
|
|
19
|
+
'EmbeddingStore',
|
|
20
|
+
'IngestionPipeline',
|
|
21
|
+
'IngestionSource',
|
|
22
|
+
'LoadedDocument',
|
|
23
|
+
'LoadedPage',
|
|
24
|
+
'Loader',
|
|
25
|
+
'ModelIngester',
|
|
26
|
+
'Splitter',
|
|
27
|
+
'TextChunk',
|
|
28
|
+
'TextLoader',
|
|
29
|
+
]
|
|
@@ -25,7 +25,7 @@ _MIN_WORDS_PER_SENT = 4
|
|
|
25
25
|
|
|
26
26
|
class DepthLimitReached(str):
    """Marker string used in place of facts when a walk goes past the depth limit."""

    def __str__(self) -> str:
        # The rendered text is fixed; whatever value the instance was built from is not shown.
        rendered = 'Truncated due to reached depth limit'
        return rendered
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
@dataclass
|
|
@@ -33,17 +33,17 @@ class VisitedObject:
|
|
|
33
33
|
obj: Any
|
|
34
34
|
|
|
35
35
|
def __str__(self) -> str:
|
|
36
|
-
return f
|
|
36
|
+
return f'Recursion reference to object {self.obj}'
|
|
37
37
|
|
|
38
38
|
|
|
39
39
|
class MissingRelation(str):
    """Marker string recorded when a foreign-key attribute resolves to ``None``."""

    def __str__(self) -> str:
        # The rendered text is fixed; whatever value the instance was built from is not shown.
        rendered = 'Relation not present'
        return rendered
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
class NoChildren(str):
    """Marker string recorded when an object yields no facts and declares no foreign keys."""

    def __str__(self) -> str:
        # The rendered text is fixed; whatever value the instance was built from is not shown.
        rendered = 'No nested data'
        return rendered
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
# UP007: use X | Y style
|
|
@@ -79,109 +79,109 @@ class DefaultIngesting(MLIngesting):
|
|
|
79
79
|
self._afacts_transform = afacts_transform
|
|
80
80
|
|
|
81
81
|
def _default_header(self, instance: Any, facts: list[str]) -> str:
|
|
82
|
-
doc = getattr(instance.__class__,
|
|
83
|
-
return (doc.strip() +
|
|
82
|
+
doc = getattr(instance.__class__, '__doc__', '') or f'Instance of {instance.__class__.__name__}'
|
|
83
|
+
return (doc.strip() + '\n\nKey facts:\n' + '\n'.join(facts)).strip()
|
|
84
84
|
|
|
85
85
|
def _walk_sync(self, obj: Any, depth: int, visited: set[tuple[str, str]]) -> list[str | Marker]:
    """Collect textual facts about ``obj`` and the objects reachable from it.

    Facts come from two places: the entries of the class's ``model_fields``
    and the attribute names listed in the class's ``FOREIGN_KEYS``. Nested
    objects (anything whose class has ``model_fields``) are walked
    recursively with ``depth + 1``.

    Returns a single marker instead of facts when ``depth`` exceeds
    ``self.max_depth`` or when the object's ``(class name, object_id)`` key is
    already in ``visited``. ``visited`` is mutated in place. Any exception
    raised while handling one field or one foreign key is logged as a warning
    and that entry is skipped.
    """
    if depth > self.max_depth:
        return [DepthLimitReached('')]

    model_cls = obj.__class__
    # Objects without an ``object_id`` are keyed by their in-memory identity.
    visit_key = (model_cls.__name__, str(getattr(obj, 'object_id', id(obj))))
    if visit_key in visited:
        return [VisitedObject(obj)]
    visited.add(visit_key)

    facts: list[str | Marker] = []

    declared = getattr(model_cls, 'model_fields', {})
    # Tolerate a ``model_fields`` value that has no ``items`` method.
    iter_items = getattr(declared, 'items', lambda: [])
    for attr_name, field_info in iter_items():
        try:
            value = getattr(obj, attr_name)
            label = getattr(field_info, 'title', None) or attr_name.replace('_', ' ').capitalize()
            if value is None:
                continue
            if isinstance(value, str | int | float | bool | date):
                facts.append(f'{label}: {value}')
                continue
            if hasattr(value.__class__, 'model_fields'):
                nested = self._walk_sync(value, depth + 1, visited)
                if nested:
                    facts.append(f'{label} → {"; ".join(map(str, nested))}')
                else:
                    facts.append(str(NoChildren('')))
                continue
            if isinstance(value, list):
                # Only scalar list members are rendered; other members are ignored.
                scalars = [str(entry) for entry in value if isinstance(entry, str | int | float)]
                if scalars:
                    facts.append(f'{label}: {", ".join(scalars)}')
        except Exception as exc:  # noqa: BLE001
            logger.warning(f'[walk_sync] field {attr_name}: {exc}')

    relation_names = getattr(model_cls, 'FOREIGN_KEYS', [])
    if not (relation_names or facts):
        facts.append(NoChildren(''))
    for relation_name in relation_names:
        try:
            related = getattr(obj, relation_name, None)
            if related is None:
                facts.append(MissingRelation(''))
            elif isinstance(related, list):
                for index, member in enumerate(related):
                    if not hasattr(member.__class__, 'model_fields'):
                        continue
                    nested = self._walk_sync(member, depth + 1, visited)
                    facts.append(f'{relation_name}[{index}] → {"; ".join(map(str, nested))}')
            elif hasattr(related.__class__, 'model_fields'):
                nested = self._walk_sync(related, depth + 1, visited)
                facts.append(f'{relation_name} → {"; ".join(map(str, nested))}')
        except Exception as exc:  # noqa: BLE001
            logger.warning(f'[walk_sync] FK {relation_name}: {exc}')
    return facts
|
|
133
133
|
|
|
134
134
|
async def _walk_async(self, obj: Any, depth: int, visited: set[tuple[str, str]]) -> list[str | Marker]:
    """Collect textual facts about ``obj`` and reachable objects, awaiting coroutine attributes.

    Facts come from the entries of the class's ``model_fields`` and from the
    attribute names listed in the class's ``FOREIGN_KEYS``. An attribute
    value that is a coroutine is awaited before it is inspected. Nested
    objects (anything whose class has ``model_fields``) are walked
    recursively with ``depth + 1``.

    Returns a single marker instead of facts when ``depth`` exceeds
    ``self.max_depth`` or when the object's ``(class name, object_id)`` key is
    already in ``visited``. ``visited`` is mutated in place. Any exception
    raised while handling one field or one foreign key is logged as a warning
    and that entry is skipped.
    """
    if depth > self.max_depth:
        return [DepthLimitReached('')]

    model_cls = obj.__class__
    # Objects without an ``object_id`` are keyed by their in-memory identity.
    visit_key = (model_cls.__name__, str(getattr(obj, 'object_id', id(obj))))
    if visit_key in visited:
        return [VisitedObject(obj)]
    visited.add(visit_key)

    facts: list[str | Marker] = []

    declared = getattr(model_cls, 'model_fields', {})
    # Tolerate a ``model_fields`` value that has no ``items`` method.
    iter_items = getattr(declared, 'items', lambda: [])
    for attr_name, field_info in iter_items():
        try:
            value = getattr(obj, attr_name)
            if asyncio.iscoroutine(value):
                value = await value
            label = getattr(field_info, 'title', None) or attr_name.replace('_', ' ').capitalize()
            if value is None:
                continue
            if isinstance(value, str | int | float | bool | date):
                facts.append(f'{label}: {value}')
                continue
            if hasattr(value.__class__, 'model_fields'):
                nested = await self._walk_async(value, depth + 1, visited)
                if nested:
                    facts.append(f'{label} → {"; ".join(map(str, nested))}')
                else:
                    facts.append(str(NoChildren('')))
                continue
            if isinstance(value, list):
                # Only scalar list members are rendered; other members are ignored.
                scalars = [str(entry) for entry in value if isinstance(entry, str | int | float)]
                if scalars:
                    facts.append(f'{label}: {", ".join(scalars)}')
        except Exception as exc:  # noqa: BLE001
            logger.warning(f'[walk_async] field {attr_name}: {exc}')

    relation_names = getattr(model_cls, 'FOREIGN_KEYS', [])
    if not (relation_names or facts):
        facts.append(NoChildren(''))
    for relation_name in relation_names:
        try:
            related = getattr(obj, relation_name, None)
            if asyncio.iscoroutine(related):
                related = await related
            if related is None:
                facts.append(MissingRelation(''))
            elif isinstance(related, list):
                for index, member in enumerate(related):
                    if not hasattr(member.__class__, 'model_fields'):
                        continue
                    nested = await self._walk_async(member, depth + 1, visited)
                    facts.append(f'{relation_name}[{index}] → {"; ".join(map(str, nested))}')
            elif hasattr(related.__class__, 'model_fields'):
                nested = await self._walk_async(related, depth + 1, visited)
                facts.append(f'{relation_name} → {"; ".join(map(str, nested))}')
        except Exception as exc:  # noqa: BLE001
            logger.warning(f'[walk_async] FK {relation_name}: {exc}')
    return facts
|
|
186
186
|
|
|
187
187
|
def collect_facts(self, instance: Any) -> list[str | Marker]:
|
|
@@ -221,38 +221,36 @@ class DefaultIngesting(MLIngesting):
|
|
|
221
221
|
return list(self._tags)
|
|
222
222
|
|
|
223
223
|
def _split(self, text: str, max_sentences: int = 7) -> list[str]:
    """Split ``text`` into chunks of whole sentences.

    The text is cut after ``.``, ``!`` or ``?`` followed by whitespace.
    Sentences with fewer than ``_MIN_WORDS_PER_SENT`` whitespace-separated
    words are dropped. The remaining sentences are packed greedily: a sentence
    joins the current chunk while the joined text stays within
    ``self.max_tokens_per_chunk`` (measured by ``self._token_len_fn``) and the
    chunk holds fewer than ``max_sentences`` sentences; otherwise the current
    chunk is closed and the sentence starts a new one. A single sentence that
    is over the token limit on its own still becomes its own chunk.

    Args:
        text: Text to split.
        max_sentences: Maximum number of sentences per chunk.

    Returns:
        The chunks in order, each ending in terminal punctuation.
    """

    def _close(parts: list[str]) -> str:
        """Join the sentences of one chunk and make sure it ends in punctuation."""
        chunk = ' '.join(parts).strip()
        # Sentences are split on ".", "!" and "?", so all three count as terminal.
        # Checking only for "." turned "...is it?" into "...is it?.".
        if chunk and not chunk.endswith(('.', '!', '?')):
            chunk += '.'
        return chunk

    sentences = [
        s.strip() for s in re.split(r'(?<=[.!?])\s+', text.strip()) if len(s.split()) >= _MIN_WORDS_PER_SENT
    ]
    chunks: list[str] = []
    current: list[str] = []
    for sentence in sentences:
        candidate = ' '.join([*current, sentence]).strip()
        if self._token_len_fn(candidate) <= self.max_tokens_per_chunk and len(current) < max_sentences:
            current.append(sentence)
            continue
        if current:
            chunks.append(_close(current))
        current = [sentence]
    if current:
        chunks.append(_close(current))
    return chunks
|
|
245
245
|
|
|
246
246
|
def _resolve_link(self, instance: Any) -> tuple[str, str]:
|
|
247
247
|
cls = instance.__class__.__name__
|
|
248
|
-
oid = getattr(instance,
|
|
248
|
+
oid = getattr(instance, 'object_id', None)
|
|
249
249
|
if oid is None:
|
|
250
|
-
oid = str(getattr(instance,
|
|
250
|
+
oid = str(getattr(instance, 'id', None) or id(instance))
|
|
251
251
|
return cls, str(oid)
|
|
252
252
|
|
|
253
|
-
def _make_records(
|
|
254
|
-
self, chunks: list[str], vectors: list[list[float]], tags: list[str]
|
|
255
|
-
) -> list[EmbeddingData]:
|
|
253
|
+
def _make_records(self, chunks: list[str], vectors: list[list[float]], tags: list[str]) -> list[EmbeddingData]:
|
|
256
254
|
out: list[EmbeddingData] = []
|
|
257
255
|
for i, (t, v) in enumerate(zip(chunks[: self.max_chunks], vectors, strict=False)):
|
|
258
256
|
out.append(EmbeddingData(chunk_index=i, raw_text=t, embedding=v, tags=tags))
|
|
@@ -262,7 +260,7 @@ class DefaultIngesting(MLIngesting):
|
|
|
262
260
|
self, instance: Any, embed_func: Callable[[str], list[float]] | None = None
|
|
263
261
|
) -> list[EmbeddingData]:
|
|
264
262
|
if embed_func is None:
|
|
265
|
-
msg =
|
|
263
|
+
msg = 'embed_func is required for DefaultIngesting.generate_embeddings'
|
|
266
264
|
raise RuntimeError(msg)
|
|
267
265
|
text = self.generate_text(instance)
|
|
268
266
|
chunks = self._split(text)
|
|
@@ -274,7 +272,7 @@ class DefaultIngesting(MLIngesting):
|
|
|
274
272
|
self, instance: Any, embed_func: Callable[[str], Awaitable[list[float]]] | None = None
|
|
275
273
|
) -> list[EmbeddingData]:
|
|
276
274
|
if embed_func is None:
|
|
277
|
-
msg =
|
|
275
|
+
msg = 'embed_func is required for DefaultIngesting.agenerate_embeddings'
|
|
278
276
|
raise RuntimeError(msg)
|
|
279
277
|
text = await self.agenerate_text(instance)
|
|
280
278
|
chunks = self._split(text)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Embedder(ABC):
    """Interface for components that turn a piece of text into an embedding vector."""

    @abstractmethod
    def embed(self, text: str) -> list[float]:
        """Return the embedding vector for ``text`` (synchronous variant)."""

    @abstractmethod
    async def aembed(self, text: str) -> list[float]:
        """Return the embedding vector for ``text`` (coroutine variant)."""
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from openai import AsyncOpenAI
|
|
6
|
+
from openai import OpenAI
|
|
7
|
+
|
|
8
|
+
from amsdal_ml.ml_config import ml_config
|
|
9
|
+
from amsdal_ml.ml_ingesting.embedders.embedder import Embedder
|
|
10
|
+
|
|
11
|
+
DEFAULT_EMBED_MODEL = ml_config.embed_model_name
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OpenAIEmbedder(Embedder):
    """Embedder that calls the OpenAI embeddings API through a sync and an async client."""

    def __init__(self, *, api_key: str | None = None, embed_model: str | None = None) -> None:
        """Resolve the API key and model name and create both clients.

        The key is taken from ``api_key``, then ``ml_config.resolved_openai_key``,
        then the ``OPENAI_API_KEY`` environment variable. The model defaults to
        ``DEFAULT_EMBED_MODEL``.

        Raises:
            RuntimeError: If no API key can be resolved.
        """
        resolved_key = api_key or ml_config.resolved_openai_key or os.environ.get('OPENAI_API_KEY')
        self.api_key = resolved_key
        if not resolved_key:
            msg = 'OPENAI_API_KEY is required for OpenAIEmbedder'
            raise RuntimeError(msg)

        self.embed_model = embed_model if embed_model else DEFAULT_EMBED_MODEL
        self.client = OpenAI(api_key=resolved_key)
        self.aclient = AsyncOpenAI(api_key=resolved_key)

    def embed(self, text: str) -> list[float]:
        """Request an embedding for ``text`` and return the first result's vector."""
        response = self.client.embeddings.create(model=self.embed_model, input=text)
        first = response.data[0]
        return first.embedding

    async def aembed(self, text: str) -> list[float]:
        """Request an embedding for ``text`` with the async client and return the first result's vector."""
        response = await self.aclient.embeddings.create(model=self.embed_model, input=text)
        first = response.data[0]
        return first.embedding
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
1
3
|
from pydantic import BaseModel
|
|
2
4
|
from pydantic import Field
|
|
3
5
|
|
|
@@ -7,3 +9,4 @@ class EmbeddingData(BaseModel):
|
|
|
7
9
|
raw_text: str = Field(..., title='Raw text used for embedding')
|
|
8
10
|
embedding: list[float] = Field(..., title='Vector embedding')
|
|
9
11
|
tags: list[str] = Field(default_factory=list, title='Embedding tags')
|
|
12
|
+
metadata: dict[str, Any] = Field(default_factory=dict, title='Embedding metadata')
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from amsdal_ml.ml_ingesting.loaders.folder_loader import FolderLoader
|
|
2
|
+
from amsdal_ml.ml_ingesting.loaders.folder_loader import PdfFolderLoader
|
|
3
|
+
from amsdal_ml.ml_ingesting.loaders.loader import Loader
|
|
4
|
+
from amsdal_ml.ml_ingesting.loaders.pdf_loader import PdfLoader
|
|
5
|
+
|
|
6
|
+
__all__ = ['FolderLoader', 'Loader', 'PdfFolderLoader', 'PdfLoader']
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Iterator
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from amsdal_ml.ml_ingesting.loaders.loader import Loader
|
|
8
|
+
from amsdal_ml.ml_ingesting.loaders.pdf_loader import PdfLoader
|
|
9
|
+
from amsdal_ml.ml_ingesting.types import IngestionSource
|
|
10
|
+
from amsdal_ml.ml_ingesting.types import LoadedDocument
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FolderLoader:
    """Generic folder loader that delegates file parsing to a Loader.

    Walks a folder recursively, opens every accepted file in binary mode and
    hands the stream to the configured loader. Subclasses narrow the set of
    files by overriding ``_accepts``.
    """

    def __init__(self, *, loader: Loader) -> None:
        """Store the loader used to parse each file."""
        self.loader = loader

    def _iter_paths(self, folder: Path) -> Iterator[Path]:
        """Yield the accepted regular files under ``folder``, recursively, in sorted order."""
        # rglob yields entries in directory-listing order, which differs between
        # filesystems; sorting keeps the resulting document order reproducible.
        for path in sorted(folder.rglob('*')):
            if path.is_file() and self._accepts(path):
                yield path

    def _accepts(self, _path: Path) -> bool:
        """Return whether a file should be loaded; the base class accepts every file."""
        return True

    def _load_path(self, path: Path, *, source: IngestionSource | None) -> LoadedDocument:
        """Load one file and record its ``filename`` and ``path`` in the document metadata.

        Existing ``filename``/``path`` metadata entries are kept (``setdefault``).
        """
        shared = source.metadata if source else None
        # Each document gets its own copy of the source metadata. A loader that
        # stores the dict it receives would otherwise make every document share
        # one dict, so the first file's filename/path would stick to all of them
        # and be written back into the caller's source.
        # Assumes source.metadata is a dict or None, matching Loader.load — TODO confirm.
        metadata = dict(shared) if shared is not None else None
        with path.open('rb') as f:
            doc = self.loader.load(f, filename=path.name, metadata=metadata)
        doc.metadata.setdefault('filename', path.name)
        doc.metadata.setdefault('path', str(path))
        return doc

    def load_all(self, folder: str | Path, *, source: IngestionSource | None = None) -> list[LoadedDocument]:
        """Load every accepted file under ``folder`` one after another.

        Args:
            folder: Root folder to scan recursively.
            source: Optional source whose ``metadata`` is passed to the loader for each file.

        Returns:
            The loaded documents in sorted path order.
        """
        root = Path(folder)
        return [self._load_path(path, source=source) for path in self._iter_paths(root)]

    async def aload_all(self, folder: str | Path, *, source: IngestionSource | None = None) -> list[LoadedDocument]:
        """Load every accepted file under ``folder``, running each load in a worker thread.

        Each file goes through the loader's synchronous ``load`` via
        ``asyncio.to_thread``; the results are gathered in sorted path order.
        """
        root = Path(folder)
        tasks = [asyncio.to_thread(self._load_path, path, source=source) for path in self._iter_paths(root)]
        return await asyncio.gather(*tasks)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class PdfFolderLoader(FolderLoader):
    """Folder loader that only picks up files with a ``.pdf`` suffix (case-insensitive)."""

    def __init__(self, *, pdf_loader: Loader | None = None) -> None:
        """Use ``pdf_loader`` when one is given, otherwise create a default ``PdfLoader``."""
        chosen = pdf_loader if pdf_loader else PdfLoader()
        super().__init__(loader=chosen)

    def _accepts(self, path: Path) -> bool:
        """Accept only paths whose suffix is ``.pdf``, ignoring case."""
        extension = path.suffix.lower()
        return extension == '.pdf'
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from typing import IO
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from amsdal_ml.ml_ingesting.types import LoadedDocument
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Loader(ABC):
    """Interface for turning an open file object into a ``LoadedDocument``."""

    @abstractmethod
    def load(
        self,
        file: IO[Any],
        *,
        filename: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> LoadedDocument:
        """Parse ``file`` and return the loaded document (synchronous variant).

        ``filename`` and ``metadata`` are optional, keyword-only hints for the
        implementation.
        """

    @abstractmethod
    async def aload(
        self,
        file: IO[Any],
        *,
        filename: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> LoadedDocument:
        """Parse ``file`` and return the loaded document (coroutine variant).

        Takes the same optional, keyword-only ``filename`` and ``metadata``
        arguments as ``load``.
        """
|