amsdal_ml 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. amsdal_ml/Third-Party Materials - AMSDAL Dependencies - License Notices.md +617 -0
  2. amsdal_ml/__about__.py +1 -1
  3. amsdal_ml/agents/__init__.py +13 -0
  4. amsdal_ml/agents/agent.py +5 -7
  5. amsdal_ml/agents/default_qa_agent.py +108 -143
  6. amsdal_ml/agents/functional_calling_agent.py +233 -0
  7. amsdal_ml/agents/mcp_client_tool.py +46 -0
  8. amsdal_ml/agents/python_tool.py +86 -0
  9. amsdal_ml/agents/retriever_tool.py +17 -8
  10. amsdal_ml/agents/tool_adapters.py +98 -0
  11. amsdal_ml/fileio/base_loader.py +7 -5
  12. amsdal_ml/fileio/openai_loader.py +16 -17
  13. amsdal_ml/mcp_client/base.py +2 -0
  14. amsdal_ml/mcp_client/http_client.py +7 -1
  15. amsdal_ml/mcp_client/stdio_client.py +21 -18
  16. amsdal_ml/mcp_server/server_retriever_stdio.py +8 -11
  17. amsdal_ml/ml_ingesting/__init__.py +29 -0
  18. amsdal_ml/ml_ingesting/default_ingesting.py +49 -51
  19. amsdal_ml/ml_ingesting/embedders/__init__.py +4 -0
  20. amsdal_ml/ml_ingesting/embedders/embedder.py +12 -0
  21. amsdal_ml/ml_ingesting/embedders/openai_embedder.py +30 -0
  22. amsdal_ml/ml_ingesting/embedding_data.py +3 -0
  23. amsdal_ml/ml_ingesting/loaders/__init__.py +6 -0
  24. amsdal_ml/ml_ingesting/loaders/folder_loader.py +52 -0
  25. amsdal_ml/ml_ingesting/loaders/loader.py +28 -0
  26. amsdal_ml/ml_ingesting/loaders/pdf_loader.py +136 -0
  27. amsdal_ml/ml_ingesting/loaders/text_loader.py +44 -0
  28. amsdal_ml/ml_ingesting/model_ingester.py +278 -0
  29. amsdal_ml/ml_ingesting/pipeline.py +131 -0
  30. amsdal_ml/ml_ingesting/pipeline_interface.py +31 -0
  31. amsdal_ml/ml_ingesting/processors/__init__.py +4 -0
  32. amsdal_ml/ml_ingesting/processors/cleaner.py +14 -0
  33. amsdal_ml/ml_ingesting/processors/text_cleaner.py +42 -0
  34. amsdal_ml/ml_ingesting/splitters/__init__.py +4 -0
  35. amsdal_ml/ml_ingesting/splitters/splitter.py +15 -0
  36. amsdal_ml/ml_ingesting/splitters/token_splitter.py +85 -0
  37. amsdal_ml/ml_ingesting/stores/__init__.py +4 -0
  38. amsdal_ml/ml_ingesting/stores/embedding_data.py +63 -0
  39. amsdal_ml/ml_ingesting/stores/store.py +22 -0
  40. amsdal_ml/ml_ingesting/types.py +40 -0
  41. amsdal_ml/ml_models/models.py +96 -4
  42. amsdal_ml/ml_models/openai_model.py +430 -122
  43. amsdal_ml/ml_models/utils.py +7 -0
  44. amsdal_ml/ml_retrievers/__init__.py +17 -0
  45. amsdal_ml/ml_retrievers/adapters.py +93 -0
  46. amsdal_ml/ml_retrievers/default_retriever.py +11 -1
  47. amsdal_ml/ml_retrievers/openai_retriever.py +27 -7
  48. amsdal_ml/ml_retrievers/query_retriever.py +487 -0
  49. amsdal_ml/ml_retrievers/retriever.py +12 -0
  50. amsdal_ml/models/embedding_model.py +7 -7
  51. amsdal_ml/prompts/__init__.py +77 -0
  52. amsdal_ml/prompts/database_query_agent.prompt +14 -0
  53. amsdal_ml/prompts/functional_calling_agent_base.prompt +9 -0
  54. amsdal_ml/prompts/nl_query_filter.prompt +318 -0
  55. amsdal_ml/{agents/promts → prompts}/react_chat.prompt +17 -8
  56. amsdal_ml/utils/__init__.py +5 -0
  57. amsdal_ml/utils/query_utils.py +189 -0
  58. amsdal_ml-0.2.0.dist-info/METADATA +293 -0
  59. amsdal_ml-0.2.0.dist-info/RECORD +72 -0
  60. {amsdal_ml-0.1.3.dist-info → amsdal_ml-0.2.0.dist-info}/WHEEL +1 -1
  61. amsdal_ml/agents/promts/__init__.py +0 -58
  62. amsdal_ml-0.1.3.dist-info/METADATA +0 -69
  63. amsdal_ml-0.1.3.dist-info/RECORD +0 -39
amsdal_ml/ml_retrievers/query_retriever.py (new file)
@@ -0,0 +1,487 @@
+ """Natural language query interface for Amsdal QuerySets."""
+
+ import json
+ from datetime import date
+ from datetime import datetime
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any
+ from typing import Generic
+ from typing import Literal
+ from typing import Optional
+ from typing import TypeVar
+ from typing import get_args
+ from typing import get_origin
+
+ from amsdal_models.classes.model import Model
+ from amsdal_models.querysets.base_queryset import QuerySetBase
+ from amsdal_utils.query.enums import Lookup
+ from pydantic import BaseModel
+
+ from amsdal_ml.ml_models.models import MLModel
+ from amsdal_ml.ml_models.utils import ResponseFormat
+ from amsdal_ml.ml_retrievers.adapters import DefaultRetrieverAdapter
+ from amsdal_ml.ml_retrievers.adapters import get_retriever_adapter
+ from amsdal_ml.ml_retrievers.retriever import Document
+ from amsdal_ml.ml_retrievers.retriever import Retriever
+ from amsdal_ml.prompts import get_prompt
+ from amsdal_ml.utils.query_utils import serialize_and_clean_record
+
+ T = TypeVar("T", bound=Model)
+ AmsdalQuerySet = QuerySetBase[T]
+
+
+ class FilterCondition(BaseModel):
+     """Single filter condition for database query."""
+
+     field: str
+     lookup: str
+     value: str | int | float | bool | None | date | datetime | list[str | int | float | bool | date | datetime]
+
+
+ class FilterResponse(BaseModel):
+     """Structured response for LLM."""
+
+     filters: list[FilterCondition]
+
+
+ class NLQueryExecutor:
+     """
+     Natural language query interface for Amsdal QuerySets.
+
+     Converts natural language queries into structured database filters using an LLM.
+
+     Supported Field Types:
+     - Primitives: str, int, float, bool
+     - Enums: Enum subclasses
+     - Literals: Literal[...]
+     - Unions: Union of supported types (no mixed with models/dict/lists/tuples/sets)
+     - Annotated: Annotated[supported_type, ...]
+     - Dates: datetime, date
+
+     Unsupported Field Types (will be skipped):
+     - Models: BaseModel/Model subclasses
+     - Dicts: dict[...]
+     - Lists/Tuples/Sets: list/tuple/set[...]
+     - Callables: Callable[...]
+     - Unknown types: bytes, custom classes, etc.
+     Example:
+         >>> llm = OpenAIModel()
+         >>> llm.setup()
+         >>> base_qs = BaseRateInfo.objects.filter(accessibility__eq='public')
+         >>> executor = NLQueryExecutor(
+         ...     llm=llm,
+         ...     queryset=base_qs,
+         ...     base_filters={'status__eq': 'active'},
+         ...     fields=['rate', 'premium_band', 'effective_date', 'rate_category'],
+         ...     fields_info={'rate': 'Current rate value as decimal (5.15% = 0.0515)'}
+         ... )
+         >>> results = await executor.execute("show me all high band rates")
+     """
+
+     _full_prompt: Optional[str]
+     MAX_VALUES_DISPLAY_LIMIT = 100
+     MAX_SCHEMA_INSPECTION_DEPTH = 10
+
+     def __init__(
+         self,
+         llm: MLModel,
+         queryset: AmsdalQuerySet[T],
+         base_filters: Optional[dict[str, Any]] = None,
+         fields: Optional[list[str]] = None,
+         fields_info: Optional[dict[str, str]] = None,
+         prompt_path: Optional[str | Path] = None,
+         llm_response_format: Optional[ResponseFormat] = None,
+         adapter: Optional[DefaultRetrieverAdapter] = None,
+     ) -> None:
+         """
+         Initialize NLQueryExecutor.
+
+         Args:
+             llm: ML model instance for query interpretation
+             queryset: AMSDAL QuerySet
+             base_filters: Optional dict of additional base filters to apply to all queries
+                 (e.g., {'accessibility__eq': 'public'})
+             fields: Optional list of field names to include in search schema.
+                 If None, all model fields are used.
+             fields_info: Optional dict mapping field names to descriptions.
+                 Takes priority over model field titles.
+                 Format: {'field_name': 'description text'}
+             prompt_path: Optional path to custom prompt file. If None, uses default 'nl_query_filter'
+                 which is specifically designed for AMSDAL ORM with comprehensive coverage of all
+                 lookup operators, field type conversions, and detailed examples for
+                 optimal query interpretation.
+             llm_response_format: Optional desired response format. Can be ResponseFormat enum or string
+                 ('JSON_SCHEMA', 'JSON_OBJECT', 'PLAIN_TEXT'). If None, will auto-detect best.
+             adapter: Optional adapter instance. If None, will auto-detect based on LLM type.
+         """
+         self.llm = llm
+         self.base_queryset = queryset
+         self.model_class = queryset.entity
+         self.base_filters = base_filters or {}
+         self.fields = fields
+         self.fields_info = fields_info or {}
+         self.prompt_path = prompt_path
+         self.adapter = adapter or get_retriever_adapter(llm)
+         self.llm_response_format = self._select_response_format(llm_response_format)
+
+         self.search_schema = self._build_search_schema()
+         self.system_prompt = self._build_system_prompt()
+         self._full_prompt = None
+
+     @property
+     def response_schema(self) -> dict[str, Any]:
+         """
+         Generates a JSON schema for the response.
+         """
+         base_schema = FilterResponse.model_json_schema()
+         return self.adapter.get_response_schema(base_schema)
+
+     async def execute(self, query: str, limit: int | None = 30) -> list[T]:
+         results: list[T] = await self.asearch(query)
+         return results[:limit] if limit is not None else results
+
+     async def asearch(self, query: str) -> list[T]:
+         prompt = get_prompt("nl_query_filter", custom_path=self.prompt_path)
+         full_prompt = prompt.render_text(
+             schema=json.dumps(self.search_schema, indent=2), query=query
+         )
+         self._full_prompt = full_prompt
+
+         response = await self.llm.ainvoke(
+             full_prompt,
+             response_format=self.llm_response_format,
+             schema=(
+                 self.response_schema
+                 if self.llm_response_format == ResponseFormat.JSON_SCHEMA
+                 else None
+             ),
+         )
+         raw_json = response.strip()
+
+         if self.llm_response_format == ResponseFormat.PLAIN_TEXT:
+             if "```json" in raw_json:
+                 raw_json = raw_json.split("```json")[1].split("```")[0].strip()
+             elif "```" in raw_json:
+                 raw_json = raw_json.split("```")[1].split("```")[0].strip()
+
+         queryset = self.base_queryset.filter(**self.base_filters)
+
+         conditions = self.adapter.parse_response(
+             raw_json,
+             is_schema_based=self.llm_response_format == ResponseFormat.JSON_SCHEMA,
+         )
+
+         if not conditions:
+             return await queryset.aexecute()  # type: ignore[attr-defined]
+
+         allowed_lookups = {lookup.value for lookup in Lookup}
+         all_filters = {}
+
+         for cond in conditions:
+             if cond.lookup not in allowed_lookups:
+                 continue
+
+             filter_key = f"{cond.field}__{cond.lookup}"
+             all_filters[filter_key] = cond.value
+
+         queryset = queryset.filter(**all_filters)
+
+         return await queryset.aexecute()  # type: ignore[attr-defined]
+
+     def search(self, query: str) -> list[T]:
+         prompt = get_prompt("nl_query_filter", custom_path=self.prompt_path)
+         full_prompt = prompt.render_text(
+             schema=json.dumps(self.search_schema, indent=2), query=query
+         )
+         self._full_prompt = full_prompt
+
+         response = self.llm.invoke(
+             full_prompt,
+             response_format=self.llm_response_format,
+             schema=(
+                 self.response_schema
+                 if self.llm_response_format == ResponseFormat.JSON_SCHEMA
+                 else None
+             ),
+         )
+         raw_json = response.strip()
+
+         if self.llm_response_format == ResponseFormat.PLAIN_TEXT:
+             if "```json" in raw_json:
+                 raw_json = raw_json.split("```json")[1].split("```")[0].strip()
+             elif "```" in raw_json:
+                 raw_json = raw_json.split("```")[1].split("```")[0].strip()
+
+         queryset = self.base_queryset.filter(**self.base_filters)
+
+         conditions = self.adapter.parse_response(
+             raw_json,
+             is_schema_based=self.llm_response_format == ResponseFormat.JSON_SCHEMA,
+         )
+
+         if not conditions:
+             return queryset.execute()  # type: ignore[attr-defined]
+
+         allowed_lookups = {lookup.value for lookup in Lookup}
+         all_filters = {}
+
+         for cond in conditions:
+             if cond.lookup not in allowed_lookups:
+                 continue
+
+             filter_key = f"{cond.field}__{cond.lookup}"
+             all_filters[filter_key] = cond.value
+
+         queryset = queryset.filter(**all_filters)
+
+         return queryset.execute()  # type: ignore[attr-defined]
+
+     def _select_response_format(
+         self,
+         requested: Optional[ResponseFormat],
+     ) -> ResponseFormat:
+         if requested:
+             return requested
+
+         supported = self.llm.supported_formats
+         if ResponseFormat.JSON_SCHEMA in supported:
+             return ResponseFormat.JSON_SCHEMA
+         if ResponseFormat.JSON_OBJECT in supported:
+             return ResponseFormat.JSON_OBJECT
+         return ResponseFormat.PLAIN_TEXT
+
+     def _build_search_schema(self) -> list[dict[str, Any]]:
+         schema = []
+         model_fields = self.model_class.model_fields
+         fields_to_process = (
+             self.fields
+             if self.fields is not None
+             else [*model_fields.keys(), *self.base_queryset._annotations.keys()]
+         )
+
+         for field_name in fields_to_process:
+             if field_name in model_fields:
+                 field_obj = model_fields[field_name]
+                 description = self.fields_info.get(
+                     field_name, field_obj.title or field_name
+                 )
+                 field_type = self._get_field_type(field_obj.annotation)
+                 if field_type is None:
+                     continue
+             elif field_name in self.base_queryset._annotations:
+                 description = f"{field_name}: computed field based on internal fields"
+                 field_type = "string"
+             else:
+                 continue
+
+             schema.append(
+                 {"name": field_name, "type": field_type, "description": description}
+             )
+
+         return schema
+
+     def is_skip_type(self, typ: Any, depth: int = 0) -> bool:
+         if depth > self.MAX_SCHEMA_INSPECTION_DEPTH:
+             return True
+
+         origin = get_origin(typ)
+
+         if origin.__name__ == "Annotated" if origin else False:
+             return self.is_skip_type(get_args(typ)[0], depth + 1)
+
+         if origin is dict:
+             return True
+
+         if origin in (list, tuple, set):
+             return True
+
+         if origin and "Union" in str(origin):
+             return any(
+                 self.is_skip_type(arg, depth + 1)
+                 for arg in get_args(typ)
+                 if arg is not type(None)
+             )
+
+         try:
+             if (
+                 isinstance(typ, type)
+                 and issubclass(typ, (BaseModel, Model))
+                 and not issubclass(typ, Enum)
+             ):
+                 return True
+         except TypeError:
+             pass
+
+         return False
+
+     def _extract_enum_literal(self, typ: Any, depth: int = 0) -> list[str] | None:
+         if depth > self.MAX_SCHEMA_INSPECTION_DEPTH:
+             return None
+
+         origin = get_origin(typ)
+
+         if origin.__name__ == "Annotated" if origin else False:
+             return self._extract_enum_literal(get_args(typ)[0], depth + 1)
+
+         try:
+             if isinstance(typ, type) and issubclass(typ, Enum):
+                 return [e.value for e in typ]
+         except TypeError:
+             pass
+
+         if origin is Literal:
+             return list(get_args(typ))
+
+         if origin and "Union" in str(origin):
+             all_values = []
+             has_primitive = False
+             for arg in get_args(typ):
+                 if arg is type(None):
+                     continue
+                 if arg in (str, int, float, bool):
+                     has_primitive = True
+                     break
+                 values = self._extract_enum_literal(arg, depth + 1)
+                 if values:
+                     all_values.extend(values)
+
+             if has_primitive or not all_values:
+                 return None
+             return list(dict.fromkeys(all_values))
+
+         return None
+
+     def _get_field_type(self, annotation) -> str | None:
+         origin = get_origin(annotation)
+
+         if origin is dict:
+             return None
+
+         # TODO: Enable list/set/tuple fields when AMSDAL adds JSONB array support
+         skip_list_fields = True
+         if skip_list_fields and origin in (
+             list,
+             tuple,
+             set,
+         ):
+             return None
+
+         if origin in (list, tuple, set):
+             args = get_args(annotation)
+             if not args:
+                 return "list of string"
+
+             inner = args[0]
+             if self.is_skip_type(inner):
+                 return None
+
+             values = self._extract_enum_literal(inner)
+             if values:
+                 if len(values) > self.MAX_VALUES_DISPLAY_LIMIT:
+                     values = [*values[: self.MAX_VALUES_DISPLAY_LIMIT], "..."]
+                 return f"list of options: {', '.join(f'{type(v).__name__}({v!r})' for v in values)}"
+
+             inner_type = self._get_primitive_type(inner)
+             if inner_type is None:
+                 return None
+             return f"list of {inner_type}"
+
+         if self.is_skip_type(annotation):
+             return None
+
+         values = self._extract_enum_literal(annotation)
+         if values:
+             if len(values) > self.MAX_VALUES_DISPLAY_LIMIT:
+                 values = [*values[: self.MAX_VALUES_DISPLAY_LIMIT], "..."]
+             return f"options: {', '.join(f'{type(v).__name__}({v!r})' for v in values)}"
+
+         return self._get_primitive_type(annotation)
+
+     def _get_primitive_type(self, typ: Any, depth: int = 0) -> str | None:
+         if depth > self.MAX_SCHEMA_INSPECTION_DEPTH:
+             return "string"
+
+         origin = get_origin(typ)
+
+         if origin.__name__ == "Annotated" if origin else False:
+             return self._get_primitive_type(get_args(typ)[0], depth + 1)
+
+         if origin and "Union" in str(origin):
+             for arg in get_args(typ):
+                 if arg is type(None):
+                     continue
+                 if arg is int or arg is float:
+                     return "number"
+                 if arg is bool:
+                     return "boolean"
+                 if arg is str:
+                     return "string"
+                 return self._get_primitive_type(arg, depth + 1) or "string"
+
+         if typ is int or typ is float:
+             return "number"
+         if typ is bool:
+             return "boolean"
+         if typ is str:
+             return "string"
+
+         type_name = typ.__name__.lower()
+         if type_name in ("int", "float"):
+             return "number"
+         if type_name == "bool":
+             return "boolean"
+         if type_name in ("str", "string"):
+             return "string"
+         if "datetime" in type_name:
+             return "datetime (YYYY-MM-DD HH:MM:SS)"
+         if "date" in type_name and "time" not in type_name:
+             return "date (YYYY-MM-DD)"
+
+         return None
+
+     def _build_system_prompt(self) -> str:
+         schema_json = json.dumps(self.search_schema, indent=2, ensure_ascii=False)
+         prompt = get_prompt('nl_query_filter', custom_path=self.prompt_path)
+         return prompt.render_text(schema=schema_json, query='{query}')
+
+
+ class NLQueryRetriever(Retriever, Generic[T]):
+     """
+     Retriever wrapper for NLQueryExecutor.
+     """
+
+     def __init__(
+         self,
+         llm: MLModel,
+         queryset: AmsdalQuerySet[T],
+         base_filters: Optional[dict[str, Any]] = None,
+         fields: Optional[list[str]] = None,
+         fields_info: Optional[dict[str, str]] = None,
+         prompt_path: Optional[str | Path] = None,
+         llm_response_format: Optional[ResponseFormat] = None,
+         adapter: Optional[DefaultRetrieverAdapter] = None,
+     ) -> None:
+         self.executor = NLQueryExecutor(
+             llm=llm,
+             queryset=queryset,
+             base_filters=base_filters,
+             fields=fields,
+             fields_info=fields_info,
+             prompt_path=prompt_path,
+             llm_response_format=llm_response_format,
+             adapter=adapter,
+         )
+
+     async def invoke(self, query: str, limit: int | None = 30) -> list[Document]:
+         results: list[T] = await self.executor.execute(query, limit)
+         documents = []
+         for item in results:
+             cleaned_data = await serialize_and_clean_record(item)
+             documents.append(
+                 Document(
+                     page_content=json.dumps(cleaned_data, ensure_ascii=False),
+                     metadata=item.model_dump(),
+                 )
+             )
+         return documents
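
The class docstring above already sketches the intended call pattern. For orientation, here is a minimal end-to-end sketch of the new module; it is illustrative and not part of the diff. It assumes an `OpenAIModel` class exported from `amsdal_ml/ml_models/openai_model.py` (the class name and `setup()` call mirror the docstring example, but the exact import path is an assumption) and a hypothetical `BaseRateInfo` Amsdal model.

```python
import asyncio

from amsdal_ml.ml_models.openai_model import OpenAIModel  # assumed import path; class name from the docstring example
from amsdal_ml.ml_retrievers.query_retriever import NLQueryExecutor, NLQueryRetriever

from myapp.models import BaseRateInfo  # hypothetical Amsdal Model


async def main() -> None:
    llm = OpenAIModel()
    llm.setup()

    # Restrict the searchable schema and pre-filter the queryset, as in the docstring example.
    executor = NLQueryExecutor(
        llm=llm,
        queryset=BaseRateInfo.objects.filter(accessibility__eq='public'),
        base_filters={'status__eq': 'active'},
        fields=['rate', 'premium_band', 'effective_date', 'rate_category'],
        fields_info={'rate': 'Current rate value as decimal (5.15% = 0.0515)'},
    )
    records = await executor.execute('show me all high band rates', limit=10)

    # The same logic exposed through the Retriever interface: results come back as Documents.
    retriever = NLQueryRetriever(llm=llm, queryset=BaseRateInfo.objects.filter(accessibility__eq='public'))
    docs = await retriever.invoke('rates effective after 2024-01-01')
    print(len(records), len(docs))


asyncio.run(main())
```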
amsdal_ml/ml_retrievers/retriever.py
@@ -3,12 +3,23 @@ from __future__ import annotations
  from abc import ABC
  from abc import abstractmethod
  from collections.abc import Iterable
+ from typing import Any
  from typing import Optional

  from pydantic import BaseModel
  from pydantic import Field


+ class Document(BaseModel):
+     page_content: str
+     metadata: dict[str, Any] = Field(default_factory=dict)
+
+
+ class Retriever(ABC):
+     @abstractmethod
+     async def invoke(self, query: str) -> list[Document]: ...
+
+
  class RetrievalChunk(BaseModel):
      object_class: str = Field(...)
      object_id: str = Field(...)
@@ -16,6 +27,7 @@ class RetrievalChunk(BaseModel):
      raw_text: str = Field(...)
      distance: float = Field(...)
      tags: list[str] = Field(default_factory=list)
+     metadata: dict[str, Any] = Field(default_factory=dict)


  class MLRetriever(ABC):
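
The new `Retriever` ABC and `Document` model (imported by query_retriever.py above) define a minimal contract: a single async `invoke` that returns documents. A toy in-memory implementation, illustrative only and not part of the package, could look like this:

```python
from amsdal_ml.ml_retrievers.retriever import Document, Retriever


class StaticRetriever(Retriever):
    """Toy retriever that matches a query against a fixed in-memory corpus."""

    def __init__(self, corpus: list[str]) -> None:
        self.corpus = corpus

    async def invoke(self, query: str) -> list[Document]:
        # Naive substring matching; a real retriever would rank by embedding distance.
        return [
            Document(page_content=text, metadata={'source': 'static'})
            for text in self.corpus
            if query.lower() in text.lower()
        ]
```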
amsdal_ml/models/embedding_model.py
@@ -8,14 +8,14 @@ from pydantic import Field

  class EmbeddingModel(Model):
      __module_type__ = ModuleType.CONTRIB
-     __table_name__ = "embedding_model"
+     __table_name__ = 'embedding_model'

-     data_object_class: str = Field(..., title="Linked object class")
-     data_object_id: str = Field(..., title="Linked object ID")
+     data_object_class: str = Field(..., title='Linked object class')
+     data_object_id: str = Field(..., title='Linked object ID')

-     chunk_index: int = Field(..., title="Chunk index")
-     raw_text: str = Field(..., title="Raw text used for embedding")
+     chunk_index: int = Field(..., title='Chunk index')
+     raw_text: str = Field(..., title='Raw text used for embedding')

      embedding: VectorField(1536)  # type: ignore[valid-type]
-     tags: list[str] = Field(default_factory=list, title="Embedding tags")
-     ml_metadata: Any = Field(default=None, title="ML metadata")
+     tags: list[str] = Field(default_factory=list, title='Embedding tags')
+     ml_metadata: Any = Field(default=None, title='ML metadata')
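
This hunk only switches the field titles from double to single quotes; the schema itself is unchanged. As a reminder of what a record holds, here is a hedged sketch of instantiating `EmbeddingModel`: the field names come from the diff, but the placeholder vector and the persistence call depend on the Amsdal runtime and are assumptions.

```python
from amsdal_ml.models.embedding_model import EmbeddingModel

chunk = EmbeddingModel(
    data_object_class='Article',           # class of the linked object
    data_object_id='article-42',            # ID of the linked object
    chunk_index=0,                          # position of this chunk within the source text
    raw_text='First chunk of the article...',
    embedding=[0.0] * 1536,                 # placeholder vector; VectorField(1536) is normally filled by an embedder
    tags=['article'],
    ml_metadata={'loader': 'pdf_loader'},
)
# Persisting the record depends on the Amsdal runtime, e.g. chunk.save() or await chunk.asave() (assumed API).
```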
amsdal_ml/prompts/__init__.py (new file)
@@ -0,0 +1,77 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+ from typing import Final
+
+ # avoid "magic numbers"
+ _PARTS_EXPECTED: Final[int] = 2
+
+
+ class _SafeDict(dict[str, Any]):
+     def __missing__(self, key: str) -> str:
+         return '{' + key + '}'
+
+
+ @dataclass
+ class Prompt:
+     name: str
+     system: str
+     user: str
+
+     def render_text(self, **kwargs: Any) -> str:
+         data = _SafeDict(**kwargs)
+         sys_txt = self.system.format_map(data)
+         usr_txt = self.user.format_map(data)
+         return f'{sys_txt}\n\n{usr_txt}'.strip()
+
+     def render_messages(self, **kwargs: Any) -> list[dict[str, str]]:
+         data = _SafeDict(**kwargs)
+         return [
+             {'role': 'system', 'content': self.system.format_map(data)},
+             {'role': 'user', 'content': self.user.format_map(data)},
+         ]
+
+
+ _prompt_cache: dict[str, Prompt] = {}
+
+
+ def _load_file(name: str, custom_path: Path | None = None) -> Prompt:
+     if custom_path:
+         path = custom_path
+     else:
+         base = Path(__file__).resolve().parent
+         path = base / f'{name}.prompt'
+
+     if not path.exists():
+         msg = f"Prompt '{name}' not found at {path}"
+         raise FileNotFoundError(msg)
+
+     raw = path.read_text(encoding='utf-8')
+     parts = raw.split('\n---\n', 1)
+     if len(parts) == _PARTS_EXPECTED:
+         system, user = parts
+     else:
+         system, user = raw, '{input}'
+     return Prompt(name=name, system=system.strip(), user=user.strip())
+
+
+ def get_prompt(name: str, custom_path: str | Path | None = None) -> Prompt:
+     """
+     Load a prompt template by name or from a custom path.
+
+     Args:
+         name: Name of the prompt (without .prompt extension) or identifier
+         custom_path: Optional path to custom prompt file. If provided, loads from this path instead.
+
+     Returns:
+         Prompt object with system and user templates
+     """
+     cache_key = str(custom_path) if custom_path else name
+
+     if cache_key not in _prompt_cache:
+         path_obj = Path(custom_path) if custom_path else None
+         _prompt_cache[cache_key] = _load_file(name, path_obj)
+
+     return _prompt_cache[cache_key]
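
`get_prompt` caches prompts by name (or by custom path) and `_load_file` splits a `.prompt` file on the first `\n---\n` into system and user parts; `_SafeDict` leaves unknown `{placeholders}` untouched during rendering. A minimal usage sketch, using the bundled `database_query_agent` prompt added in this release:

```python
from amsdal_ml.prompts import get_prompt

# Loads amsdal_ml/prompts/database_query_agent.prompt and caches it by name.
prompt = get_prompt('database_query_agent')

# Placeholders that are not supplied stay as literal '{name}' thanks to _SafeDict.
text = prompt.render_text(input='List all active rates')
messages = prompt.render_messages(input='List all active rates')
print(messages[0]['role'], len(text))
```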
amsdal_ml/prompts/database_query_agent.prompt (new file)
@@ -0,0 +1,14 @@
+ You are a knowledgeable assistant specialized in database queries. Use only the provided context (including any user-attached files) to answer accurately and concisely. If relevant information such as product names, crediting strategies, interest rates, or surrender charge schedules is available, include it. If the context is insufficient, clearly say you don't have enough data and avoid assumptions.
+
+ CRITICAL TOOL USAGE:
+ - For ANY questions that require retrieving specific data from the database, ALWAYS use the 'search_database' tool first. Pass the EXACT user's question as the query parameter without rephrasing or modifying it.
+ - If the user asks to display data in a table format, use the 'render_markdown_table' tool after retrieving the data with 'search_database'.
+
+ IMPORTANT OUTPUT INSTRUCTION:
+ - When using the 'render_markdown_table' tool, your Final Answer must contain ONLY the raw Markdown table string.
+ - Do NOT add introductory text (e.g., 'Here are the results').
+ - Do NOT add concluding remarks.
+ - Do NOT wrap the table in markdown code blocks (like ```markdown ... ```). Just output the table directly.
+ - If you use 'search_database' without 'render_markdown_table', provide a natural language summary of the results.
+ ---
+ {input}
amsdal_ml/prompts/functional_calling_agent_base.prompt (new file)
@@ -0,0 +1,9 @@
+ You are an assistant that can call tools to help answer questions.
+
+ CRITICAL TOOL USAGE:
+ - Use available tools when necessary to provide accurate responses.
+
+ IMPORTANT OUTPUT INSTRUCTION:
+ - Provide clear and concise answers based on tool results.
+ ---
+ {input}