amsdal_ml 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. amsdal_ml/Third-Party Materials - AMSDAL Dependencies - License Notices.md +617 -0
  2. amsdal_ml/__about__.py +1 -1
  3. amsdal_ml/agents/__init__.py +13 -0
  4. amsdal_ml/agents/agent.py +5 -7
  5. amsdal_ml/agents/default_qa_agent.py +108 -143
  6. amsdal_ml/agents/functional_calling_agent.py +233 -0
  7. amsdal_ml/agents/mcp_client_tool.py +46 -0
  8. amsdal_ml/agents/python_tool.py +86 -0
  9. amsdal_ml/agents/retriever_tool.py +5 -6
  10. amsdal_ml/agents/tool_adapters.py +98 -0
  11. amsdal_ml/fileio/base_loader.py +7 -5
  12. amsdal_ml/fileio/openai_loader.py +16 -17
  13. amsdal_ml/mcp_client/base.py +2 -0
  14. amsdal_ml/mcp_client/http_client.py +7 -1
  15. amsdal_ml/mcp_client/stdio_client.py +19 -16
  16. amsdal_ml/mcp_server/server_retriever_stdio.py +8 -11
  17. amsdal_ml/ml_ingesting/__init__.py +29 -0
  18. amsdal_ml/ml_ingesting/default_ingesting.py +49 -51
  19. amsdal_ml/ml_ingesting/embedders/__init__.py +4 -0
  20. amsdal_ml/ml_ingesting/embedders/embedder.py +12 -0
  21. amsdal_ml/ml_ingesting/embedders/openai_embedder.py +30 -0
  22. amsdal_ml/ml_ingesting/embedding_data.py +3 -0
  23. amsdal_ml/ml_ingesting/loaders/__init__.py +6 -0
  24. amsdal_ml/ml_ingesting/loaders/folder_loader.py +52 -0
  25. amsdal_ml/ml_ingesting/loaders/loader.py +28 -0
  26. amsdal_ml/ml_ingesting/loaders/pdf_loader.py +136 -0
  27. amsdal_ml/ml_ingesting/loaders/text_loader.py +44 -0
  28. amsdal_ml/ml_ingesting/model_ingester.py +278 -0
  29. amsdal_ml/ml_ingesting/pipeline.py +131 -0
  30. amsdal_ml/ml_ingesting/pipeline_interface.py +31 -0
  31. amsdal_ml/ml_ingesting/processors/__init__.py +4 -0
  32. amsdal_ml/ml_ingesting/processors/cleaner.py +14 -0
  33. amsdal_ml/ml_ingesting/processors/text_cleaner.py +42 -0
  34. amsdal_ml/ml_ingesting/splitters/__init__.py +4 -0
  35. amsdal_ml/ml_ingesting/splitters/splitter.py +15 -0
  36. amsdal_ml/ml_ingesting/splitters/token_splitter.py +85 -0
  37. amsdal_ml/ml_ingesting/stores/__init__.py +4 -0
  38. amsdal_ml/ml_ingesting/stores/embedding_data.py +63 -0
  39. amsdal_ml/ml_ingesting/stores/store.py +22 -0
  40. amsdal_ml/ml_ingesting/types.py +40 -0
  41. amsdal_ml/ml_models/models.py +96 -4
  42. amsdal_ml/ml_models/openai_model.py +430 -122
  43. amsdal_ml/ml_models/utils.py +7 -0
  44. amsdal_ml/ml_retrievers/__init__.py +17 -0
  45. amsdal_ml/ml_retrievers/adapters.py +93 -0
  46. amsdal_ml/ml_retrievers/default_retriever.py +11 -1
  47. amsdal_ml/ml_retrievers/openai_retriever.py +27 -7
  48. amsdal_ml/ml_retrievers/query_retriever.py +487 -0
  49. amsdal_ml/ml_retrievers/retriever.py +12 -0
  50. amsdal_ml/models/embedding_model.py +7 -7
  51. amsdal_ml/prompts/__init__.py +77 -0
  52. amsdal_ml/prompts/database_query_agent.prompt +14 -0
  53. amsdal_ml/prompts/functional_calling_agent_base.prompt +9 -0
  54. amsdal_ml/prompts/nl_query_filter.prompt +318 -0
  55. amsdal_ml/{agents/promts → prompts}/react_chat.prompt +17 -8
  56. amsdal_ml/utils/__init__.py +5 -0
  57. amsdal_ml/utils/query_utils.py +189 -0
  58. {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.0.dist-info}/METADATA +59 -1
  59. amsdal_ml-0.2.0.dist-info/RECORD +72 -0
  60. {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.0.dist-info}/WHEEL +1 -1
  61. amsdal_ml/agents/promts/__init__.py +0 -58
  62. amsdal_ml-0.1.4.dist-info/RECORD +0 -39
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  import base64
5
+ import logging
5
6
  import os
6
7
  from collections.abc import Iterable
7
8
  from contextlib import AsyncExitStack
@@ -16,6 +17,8 @@ from mcp.shared.exceptions import McpError
16
17
  from amsdal_ml.mcp_client.base import ToolClient
17
18
  from amsdal_ml.mcp_client.base import ToolInfo
18
19
 
20
+ logger = logging.getLogger(__name__)
21
+
19
22
 
20
23
  class StdioClient(ToolClient):
21
24
  """
@@ -23,22 +26,22 @@ class StdioClient(ToolClient):
23
26
  """
24
27
 
25
28
  def __init__(
26
- self,
27
- alias: str,
28
- module_or_cmd: str,
29
- *args: str,
30
- persist_session: bool = True,
31
- send_amsdal_config: bool = True,
29
+ self,
30
+ alias: str,
31
+ module_or_cmd: str,
32
+ *args: str,
33
+ persist_session: bool = True,
34
+ send_amsdal_config: bool = True,
32
35
  ):
33
36
  self.alias = alias
34
- if module_or_cmd in ("python", "python3"):
37
+ if module_or_cmd in ('python', 'python3'):
35
38
  self._command = module_or_cmd
36
39
  self._args = list(args)
37
40
  else:
38
- self._command = "python"
39
- self._args = ["-m", module_or_cmd]
41
+ self._command = 'python'
42
+ self._args = ['-m', module_or_cmd]
40
43
  if send_amsdal_config:
41
- self._args.append("--amsdal-config")
44
+ self._args.append('--amsdal-config')
42
45
  self._args.append(self._build_amsdal_config_arg())
43
46
  self._persist = persist_session
44
47
  self._lock = asyncio.Lock()
@@ -82,8 +85,8 @@ class StdioClient(ToolClient):
82
85
  ToolInfo(
83
86
  alias=alias,
84
87
  name=t.name,
85
- description=(getattr(t, "description", None) or ""),
86
- input_schema=(getattr(t, "inputSchema", None) or {}),
88
+ description=(getattr(t, 'description', None) or ''),
89
+ input_schema=(getattr(t, 'inputSchema', None) or {}),
87
90
  )
88
91
  for t in resp_tools
89
92
  ]
@@ -127,20 +130,20 @@ class StdioClient(ToolClient):
127
130
  rx, tx = await stack.enter_async_context(stdio_client(params))
128
131
  s = await stack.enter_async_context(ClientSession(rx, tx))
129
132
  await s.initialize()
130
- print("Calling tool:", tool_name, "with args:", args) # noqa: T201
133
+ logger.debug("Calling tool: %s with args: %s", tool_name, args)
131
134
  res = await self._call_with_timeout(s.call_tool(tool_name, args), timeout=timeout)
132
- return getattr(res, "content", res)
135
+ return getattr(res, 'content', res)
133
136
 
134
137
  # Persistent session path
135
138
  s = await self._ensure_session()
136
139
  try:
137
140
  res = await self._call_with_timeout(s.call_tool(tool_name, args), timeout=timeout)
138
- return getattr(res, "content", res)
141
+ return getattr(res, 'content', res)
139
142
  except (TimeoutError, McpError):
140
143
  await self._reset_session()
141
144
  s = await self._ensure_session()
142
145
  res = await self._call_with_timeout(s.call_tool(tool_name, args), timeout=timeout)
143
- return getattr(res, "content", res)
146
+ return getattr(res, 'content', res)
144
147
 
145
148
  def _build_amsdal_config_arg(self) -> str:
146
149
  """
@@ -18,32 +18,29 @@ from amsdal_ml.agents.retriever_tool import retriever_search
18
18
  logging.basicConfig(
19
19
  level=logging.INFO,
20
20
  format='%(asctime)s [%(levelname)s] %(message)s',
21
- handlers=[
22
- logging.FileHandler("server.log"),
23
- logging.StreamHandler(sys.stdout)
24
- ]
21
+ handlers=[logging.FileHandler('server.log'), logging.StreamHandler(sys.stdout)],
25
22
  )
26
23
 
27
24
  parser = argparse.ArgumentParser()
28
- parser.add_argument("--amsdal-config", required=False, help="Base64-encoded config string")
25
+ parser.add_argument('--amsdal-config', required=False, help='Base64-encoded config string')
29
26
  args = parser.parse_args()
30
27
 
31
- logging.info(f"Starting server with args: {args}")
28
+ logging.info(f'Starting server with args: {args}')
32
29
 
33
30
  if args.amsdal_config:
34
- decoded = base64.b64decode(args.amsdal_config).decode("utf-8")
31
+ decoded = base64.b64decode(args.amsdal_config).decode('utf-8')
35
32
  amsdal_config = AmsdalConfig(**json.loads(decoded))
36
- logging.info(f"Loaded Amsdal config: {amsdal_config}")
33
+ logging.info(f'Loaded Amsdal config: {amsdal_config}')
37
34
  AmsdalConfigManager().set_config(amsdal_config)
38
35
 
39
36
  manager: Any
40
37
  if amsdal_config.async_mode:
41
38
  manager = AsyncAmsdalManager()
42
- logging.info("pre-setup")
39
+ logging.info('pre-setup')
43
40
  asyncio.run(cast(Any, manager).setup())
44
- logging.info("post-setup")
41
+ logging.info('post-setup')
45
42
  asyncio.run(cast(Any, manager).post_setup())
46
- logging.info("manager inited")
43
+ logging.info('manager inited')
47
44
  else:
48
45
  manager = AmsdalManager()
49
46
  cast(Any, manager).setup()
@@ -0,0 +1,29 @@
1
+ from amsdal_ml.ml_ingesting.embedders.embedder import Embedder
2
+ from amsdal_ml.ml_ingesting.loaders.loader import Loader
3
+ from amsdal_ml.ml_ingesting.loaders.text_loader import TextLoader
4
+ from amsdal_ml.ml_ingesting.model_ingester import ModelIngester
5
+ from amsdal_ml.ml_ingesting.pipeline import DefaultIngestionPipeline
6
+ from amsdal_ml.ml_ingesting.pipeline_interface import IngestionPipeline
7
+ from amsdal_ml.ml_ingesting.processors.cleaner import Cleaner
8
+ from amsdal_ml.ml_ingesting.splitters.splitter import Splitter
9
+ from amsdal_ml.ml_ingesting.stores.store import EmbeddingStore
10
+ from amsdal_ml.ml_ingesting.types import IngestionSource
11
+ from amsdal_ml.ml_ingesting.types import LoadedDocument
12
+ from amsdal_ml.ml_ingesting.types import LoadedPage
13
+ from amsdal_ml.ml_ingesting.types import TextChunk
14
+
15
+ __all__ = [
16
+ 'Cleaner',
17
+ 'DefaultIngestionPipeline',
18
+ 'Embedder',
19
+ 'EmbeddingStore',
20
+ 'IngestionPipeline',
21
+ 'IngestionSource',
22
+ 'LoadedDocument',
23
+ 'LoadedPage',
24
+ 'Loader',
25
+ 'ModelIngester',
26
+ 'Splitter',
27
+ 'TextChunk',
28
+ 'TextLoader',
29
+ ]
@@ -25,7 +25,7 @@ _MIN_WORDS_PER_SENT = 4
25
25
 
26
26
  class DepthLimitReached(str):
27
27
  def __str__(self) -> str:
28
- return "Truncated due to reached depth limit"
28
+ return 'Truncated due to reached depth limit'
29
29
 
30
30
 
31
31
  @dataclass
@@ -33,17 +33,17 @@ class VisitedObject:
33
33
  obj: Any
34
34
 
35
35
  def __str__(self) -> str:
36
- return f"Recursion reference to object {self.obj}"
36
+ return f'Recursion reference to object {self.obj}'
37
37
 
38
38
 
39
39
  class MissingRelation(str):
40
40
  def __str__(self) -> str:
41
- return "Relation not present"
41
+ return 'Relation not present'
42
42
 
43
43
 
44
44
  class NoChildren(str):
45
45
  def __str__(self) -> str:
46
- return "No nested data"
46
+ return 'No nested data'
47
47
 
48
48
 
49
49
  # UP007: use X | Y style
@@ -79,109 +79,109 @@ class DefaultIngesting(MLIngesting):
79
79
  self._afacts_transform = afacts_transform
80
80
 
81
81
  def _default_header(self, instance: Any, facts: list[str]) -> str:
82
- doc = getattr(instance.__class__, "__doc__", "") or f"Instance of {instance.__class__.__name__}"
83
- return (doc.strip() + "\n\nKey facts:\n" + "\n".join(facts)).strip()
82
+ doc = getattr(instance.__class__, '__doc__', '') or f'Instance of {instance.__class__.__name__}'
83
+ return (doc.strip() + '\n\nKey facts:\n' + '\n'.join(facts)).strip()
84
84
 
85
85
  def _walk_sync(self, obj: Any, depth: int, visited: set[tuple[str, str]]) -> list[str | Marker]:
86
86
  if depth > self.max_depth:
87
- return [DepthLimitReached("")]
88
- key = (obj.__class__.__name__, str(getattr(obj, "object_id", id(obj))))
87
+ return [DepthLimitReached('')]
88
+ key = (obj.__class__.__name__, str(getattr(obj, 'object_id', id(obj))))
89
89
  if key in visited:
90
90
  return [VisitedObject(obj)]
91
91
  visited.add(key)
92
92
 
93
93
  out: list[str | Marker] = []
94
- fields = getattr(obj.__class__, "model_fields", {})
95
- for name, field in getattr(fields, "items", lambda: [])():
94
+ fields = getattr(obj.__class__, 'model_fields', {})
95
+ for name, field in getattr(fields, 'items', lambda: [])():
96
96
  try:
97
97
  v = getattr(obj, name)
98
- title = getattr(field, "title", None) or name.replace("_", " ").capitalize()
98
+ title = getattr(field, 'title', None) or name.replace('_', ' ').capitalize()
99
99
  if v is None:
100
100
  continue
101
101
  if isinstance(v, str | int | float | bool | date):
102
- out.append(f"{title}: {v}")
103
- elif hasattr(v.__class__, "model_fields"):
102
+ out.append(f'{title}: {v}')
103
+ elif hasattr(v.__class__, 'model_fields'):
104
104
  sub = self._walk_sync(v, depth + 1, visited)
105
- out.append(f'{title} → {"; ".join(map(str, sub))}' if sub else str(NoChildren("")))
105
+ out.append(f'{title} → {"; ".join(map(str, sub))}' if sub else str(NoChildren('')))
106
106
  elif isinstance(v, list):
107
107
  simple = [str(x) for x in v if isinstance(x, str | int | float)]
108
108
  if simple:
109
109
  out.append(f'{title}: {", ".join(simple)}')
110
110
  except Exception as e: # noqa: BLE001
111
- logger.warning(f"[walk_sync] field {name}: {e}")
111
+ logger.warning(f'[walk_sync] field {name}: {e}')
112
112
 
113
- fks = getattr(obj.__class__, "FOREIGN_KEYS", [])
113
+ fks = getattr(obj.__class__, 'FOREIGN_KEYS', [])
114
114
  if not fks and not out:
115
- out.append(NoChildren(""))
115
+ out.append(NoChildren(''))
116
116
  for fk in fks:
117
117
  try:
118
118
  rel = getattr(obj, fk, None)
119
119
  if rel is None:
120
- out.append(MissingRelation(""))
120
+ out.append(MissingRelation(''))
121
121
  continue
122
122
  if isinstance(rel, list):
123
123
  for i, item in enumerate(rel):
124
- if hasattr(item.__class__, "model_fields"):
124
+ if hasattr(item.__class__, 'model_fields'):
125
125
  sub = self._walk_sync(item, depth + 1, visited)
126
126
  out.append(f'{fk}[{i}] → {"; ".join(map(str, sub))}')
127
- elif hasattr(rel.__class__, "model_fields"):
127
+ elif hasattr(rel.__class__, 'model_fields'):
128
128
  sub = self._walk_sync(rel, depth + 1, visited)
129
129
  out.append(f'{fk} → {"; ".join(map(str, sub))}')
130
130
  except Exception as e: # noqa: BLE001
131
- logger.warning(f"[walk_sync] FK {fk}: {e}")
131
+ logger.warning(f'[walk_sync] FK {fk}: {e}')
132
132
  return out
133
133
 
134
134
  async def _walk_async(self, obj: Any, depth: int, visited: set[tuple[str, str]]) -> list[str | Marker]:
135
135
  if depth > self.max_depth:
136
- return [DepthLimitReached("")]
137
- key = (obj.__class__.__name__, str(getattr(obj, "object_id", id(obj))))
136
+ return [DepthLimitReached('')]
137
+ key = (obj.__class__.__name__, str(getattr(obj, 'object_id', id(obj))))
138
138
  if key in visited:
139
139
  return [VisitedObject(obj)]
140
140
  visited.add(key)
141
141
 
142
142
  out: list[str | Marker] = []
143
- fields = getattr(obj.__class__, "model_fields", {})
144
- for name, field in getattr(fields, "items", lambda: [])():
143
+ fields = getattr(obj.__class__, 'model_fields', {})
144
+ for name, field in getattr(fields, 'items', lambda: [])():
145
145
  try:
146
146
  v = getattr(obj, name)
147
147
  if asyncio.iscoroutine(v):
148
148
  v = await v
149
- title = getattr(field, "title", None) or name.replace("_", " ").capitalize()
149
+ title = getattr(field, 'title', None) or name.replace('_', ' ').capitalize()
150
150
  if v is None:
151
151
  continue
152
152
  if isinstance(v, str | int | float | bool | date):
153
- out.append(f"{title}: {v}")
154
- elif hasattr(v.__class__, "model_fields"):
153
+ out.append(f'{title}: {v}')
154
+ elif hasattr(v.__class__, 'model_fields'):
155
155
  sub = await self._walk_async(v, depth + 1, visited)
156
- out.append(f'{title} → {"; ".join(map(str, sub))}' if sub else str(NoChildren("")))
156
+ out.append(f'{title} → {"; ".join(map(str, sub))}' if sub else str(NoChildren('')))
157
157
  elif isinstance(v, list):
158
158
  simple = [str(x) for x in v if isinstance(x, str | int | float)]
159
159
  if simple:
160
160
  out.append(f'{title}: {", ".join(simple)}')
161
161
  except Exception as e: # noqa: BLE001
162
- logger.warning(f"[walk_async] field {name}: {e}")
162
+ logger.warning(f'[walk_async] field {name}: {e}')
163
163
 
164
- fks = getattr(obj.__class__, "FOREIGN_KEYS", [])
164
+ fks = getattr(obj.__class__, 'FOREIGN_KEYS', [])
165
165
  if not fks and not out:
166
- out.append(NoChildren(""))
166
+ out.append(NoChildren(''))
167
167
  for fk in fks:
168
168
  try:
169
169
  rel = getattr(obj, fk, None)
170
170
  if asyncio.iscoroutine(rel):
171
171
  rel = await rel
172
172
  if rel is None:
173
- out.append(MissingRelation(""))
173
+ out.append(MissingRelation(''))
174
174
  continue
175
175
  if isinstance(rel, list):
176
176
  for i, item in enumerate(rel):
177
- if hasattr(item.__class__, "model_fields"):
177
+ if hasattr(item.__class__, 'model_fields'):
178
178
  sub = await self._walk_async(item, depth + 1, visited)
179
179
  out.append(f'{fk}[{i}] → {"; ".join(map(str, sub))}')
180
- elif hasattr(rel.__class__, "model_fields"):
180
+ elif hasattr(rel.__class__, 'model_fields'):
181
181
  sub = await self._walk_async(rel, depth + 1, visited)
182
182
  out.append(f'{fk} → {"; ".join(map(str, sub))}')
183
183
  except Exception as e: # noqa: BLE001
184
- logger.warning(f"[walk_async] FK {fk}: {e}")
184
+ logger.warning(f'[walk_async] FK {fk}: {e}')
185
185
  return out
186
186
 
187
187
  def collect_facts(self, instance: Any) -> list[str | Marker]:
@@ -221,38 +221,36 @@ class DefaultIngesting(MLIngesting):
221
221
  return list(self._tags)
222
222
 
223
223
  def _split(self, text: str, max_sentences: int = 7) -> list[str]:
224
- sents = re.split(r"(?<=[.!?])\s+", text.strip())
224
+ sents = re.split(r'(?<=[.!?])\s+', text.strip())
225
225
  sents = [s.strip() for s in sents if len(s.split()) >= _MIN_WORDS_PER_SENT]
226
226
  chunks: list[str] = []
227
227
  cur: list[str] = []
228
228
  for s in sents:
229
- proposal = (" ".join([*cur, s])).strip()
229
+ proposal = (' '.join([*cur, s])).strip()
230
230
  if self._token_len_fn(proposal) <= self.max_tokens_per_chunk and len(cur) < max_sentences:
231
231
  cur.append(s)
232
232
  else:
233
233
  if cur:
234
- ch = " ".join(cur).strip()
235
- if ch and not ch.endswith("."):
236
- ch += "."
234
+ ch = ' '.join(cur).strip()
235
+ if ch and not ch.endswith('.'):
236
+ ch += '.'
237
237
  chunks.append(ch)
238
238
  cur = [s]
239
239
  if cur:
240
- ch = " ".join(cur).strip()
241
- if ch and not ch.endswith("."):
242
- ch += "."
240
+ ch = ' '.join(cur).strip()
241
+ if ch and not ch.endswith('.'):
242
+ ch += '.'
243
243
  chunks.append(ch)
244
244
  return chunks
245
245
 
246
246
  def _resolve_link(self, instance: Any) -> tuple[str, str]:
247
247
  cls = instance.__class__.__name__
248
- oid = getattr(instance, "object_id", None)
248
+ oid = getattr(instance, 'object_id', None)
249
249
  if oid is None:
250
- oid = str(getattr(instance, "id", None) or id(instance))
250
+ oid = str(getattr(instance, 'id', None) or id(instance))
251
251
  return cls, str(oid)
252
252
 
253
- def _make_records(
254
- self, chunks: list[str], vectors: list[list[float]], tags: list[str]
255
- ) -> list[EmbeddingData]:
253
+ def _make_records(self, chunks: list[str], vectors: list[list[float]], tags: list[str]) -> list[EmbeddingData]:
256
254
  out: list[EmbeddingData] = []
257
255
  for i, (t, v) in enumerate(zip(chunks[: self.max_chunks], vectors, strict=False)):
258
256
  out.append(EmbeddingData(chunk_index=i, raw_text=t, embedding=v, tags=tags))
@@ -262,7 +260,7 @@ class DefaultIngesting(MLIngesting):
262
260
  self, instance: Any, embed_func: Callable[[str], list[float]] | None = None
263
261
  ) -> list[EmbeddingData]:
264
262
  if embed_func is None:
265
- msg = "embed_func is required for DefaultIngesting.generate_embeddings"
263
+ msg = 'embed_func is required for DefaultIngesting.generate_embeddings'
266
264
  raise RuntimeError(msg)
267
265
  text = self.generate_text(instance)
268
266
  chunks = self._split(text)
@@ -274,7 +272,7 @@ class DefaultIngesting(MLIngesting):
274
272
  self, instance: Any, embed_func: Callable[[str], Awaitable[list[float]]] | None = None
275
273
  ) -> list[EmbeddingData]:
276
274
  if embed_func is None:
277
- msg = "embed_func is required for DefaultIngesting.agenerate_embeddings"
275
+ msg = 'embed_func is required for DefaultIngesting.agenerate_embeddings'
278
276
  raise RuntimeError(msg)
279
277
  text = await self.agenerate_text(instance)
280
278
  chunks = self._split(text)
@@ -0,0 +1,4 @@
1
+ from amsdal_ml.ml_ingesting.embedders.embedder import Embedder
2
+ from amsdal_ml.ml_ingesting.embedders.openai_embedder import OpenAIEmbedder
3
+
4
+ __all__ = ['Embedder', 'OpenAIEmbedder']
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from abc import abstractmethod
5
+
6
+
7
+ class Embedder(ABC):
8
+ @abstractmethod
9
+ def embed(self, text: str) -> list[float]: ...
10
+
11
+ @abstractmethod
12
+ async def aembed(self, text: str) -> list[float]: ...
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ from openai import AsyncOpenAI
6
+ from openai import OpenAI
7
+
8
+ from amsdal_ml.ml_config import ml_config
9
+ from amsdal_ml.ml_ingesting.embedders.embedder import Embedder
10
+
11
+ DEFAULT_EMBED_MODEL = ml_config.embed_model_name
12
+
13
+
14
+ class OpenAIEmbedder(Embedder):
15
+ def __init__(self, *, api_key: str | None = None, embed_model: str | None = None) -> None:
16
+ self.api_key = api_key or ml_config.resolved_openai_key or os.getenv('OPENAI_API_KEY')
17
+ if not self.api_key:
18
+ msg = 'OPENAI_API_KEY is required for OpenAIEmbedder'
19
+ raise RuntimeError(msg)
20
+ self.embed_model = embed_model or DEFAULT_EMBED_MODEL
21
+ self.client = OpenAI(api_key=self.api_key)
22
+ self.aclient = AsyncOpenAI(api_key=self.api_key)
23
+
24
+ def embed(self, text: str) -> list[float]:
25
+ resp = self.client.embeddings.create(model=self.embed_model, input=text)
26
+ return resp.data[0].embedding
27
+
28
+ async def aembed(self, text: str) -> list[float]:
29
+ resp = await self.aclient.embeddings.create(model=self.embed_model, input=text)
30
+ return resp.data[0].embedding
@@ -1,3 +1,5 @@
1
+ from typing import Any
2
+
1
3
  from pydantic import BaseModel
2
4
  from pydantic import Field
3
5
 
@@ -7,3 +9,4 @@ class EmbeddingData(BaseModel):
7
9
  raw_text: str = Field(..., title='Raw text used for embedding')
8
10
  embedding: list[float] = Field(..., title='Vector embedding')
9
11
  tags: list[str] = Field(default_factory=list, title='Embedding tags')
12
+ metadata: dict[str, Any] = Field(default_factory=dict, title='Embedding metadata')
@@ -0,0 +1,6 @@
1
+ from amsdal_ml.ml_ingesting.loaders.folder_loader import FolderLoader
2
+ from amsdal_ml.ml_ingesting.loaders.folder_loader import PdfFolderLoader
3
+ from amsdal_ml.ml_ingesting.loaders.loader import Loader
4
+ from amsdal_ml.ml_ingesting.loaders.pdf_loader import PdfLoader
5
+
6
+ __all__ = ['FolderLoader', 'Loader', 'PdfFolderLoader', 'PdfLoader']
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections.abc import Iterator
5
+ from pathlib import Path
6
+
7
+ from amsdal_ml.ml_ingesting.loaders.loader import Loader
8
+ from amsdal_ml.ml_ingesting.loaders.pdf_loader import PdfLoader
9
+ from amsdal_ml.ml_ingesting.types import IngestionSource
10
+ from amsdal_ml.ml_ingesting.types import LoadedDocument
11
+
12
+
13
+ class FolderLoader:
14
+ """Generic folder loader that delegates file parsing to a Loader."""
15
+
16
+ def __init__(self, *, loader: Loader) -> None:
17
+ self.loader = loader
18
+
19
+ def _iter_paths(self, folder: Path) -> Iterator[Path]:
20
+ for path in folder.rglob('*'):
21
+ if path.is_file() and self._accepts(path):
22
+ yield path
23
+
24
+ def _accepts(self, _path: Path) -> bool:
25
+ return True
26
+
27
+ def _load_path(self, path: Path, *, source: IngestionSource | None) -> LoadedDocument:
28
+ with path.open('rb') as f:
29
+ doc = self.loader.load(f, filename=path.name, metadata=(source.metadata if source else None))
30
+ doc.metadata.setdefault('filename', path.name)
31
+ doc.metadata.setdefault('path', str(path))
32
+ return doc
33
+
34
+ def load_all(self, folder: str | Path, *, source: IngestionSource | None = None) -> list[LoadedDocument]:
35
+ root = Path(folder)
36
+ docs: list[LoadedDocument] = []
37
+ for path in self._iter_paths(root):
38
+ docs.append(self._load_path(path, source=source))
39
+ return docs
40
+
41
+ async def aload_all(self, folder: str | Path, *, source: IngestionSource | None = None) -> list[LoadedDocument]:
42
+ root = Path(folder)
43
+ tasks = [asyncio.to_thread(self._load_path, path, source=source) for path in self._iter_paths(root)]
44
+ return await asyncio.gather(*tasks)
45
+
46
+
47
+ class PdfFolderLoader(FolderLoader):
48
+ def __init__(self, *, pdf_loader: Loader | None = None) -> None:
49
+ super().__init__(loader=pdf_loader or PdfLoader())
50
+
51
+ def _accepts(self, path: Path) -> bool:
52
+ return path.suffix.lower() == '.pdf'
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from abc import abstractmethod
5
+ from typing import IO
6
+ from typing import Any
7
+
8
+ from amsdal_ml.ml_ingesting.types import LoadedDocument
9
+
10
+
11
+ class Loader(ABC):
12
+ @abstractmethod
13
+ def load(
14
+ self,
15
+ file: IO[Any],
16
+ *,
17
+ filename: str | None = None,
18
+ metadata: dict[str, Any] | None = None,
19
+ ) -> LoadedDocument: ...
20
+
21
+ @abstractmethod
22
+ async def aload(
23
+ self,
24
+ file: IO[Any],
25
+ *,
26
+ filename: str | None = None,
27
+ metadata: dict[str, Any] | None = None,
28
+ ) -> LoadedDocument: ...