amsdal_ml 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amsdal_ml/Third-Party Materials - AMSDAL Dependencies - License Notices.md +617 -0
- amsdal_ml/__about__.py +1 -1
- amsdal_ml/agents/__init__.py +13 -0
- amsdal_ml/agents/agent.py +5 -7
- amsdal_ml/agents/default_qa_agent.py +108 -143
- amsdal_ml/agents/functional_calling_agent.py +233 -0
- amsdal_ml/agents/mcp_client_tool.py +46 -0
- amsdal_ml/agents/python_tool.py +86 -0
- amsdal_ml/agents/retriever_tool.py +17 -8
- amsdal_ml/agents/tool_adapters.py +98 -0
- amsdal_ml/fileio/base_loader.py +7 -5
- amsdal_ml/fileio/openai_loader.py +16 -17
- amsdal_ml/mcp_client/base.py +2 -0
- amsdal_ml/mcp_client/http_client.py +7 -1
- amsdal_ml/mcp_client/stdio_client.py +21 -18
- amsdal_ml/mcp_server/server_retriever_stdio.py +8 -11
- amsdal_ml/ml_ingesting/__init__.py +29 -0
- amsdal_ml/ml_ingesting/default_ingesting.py +49 -51
- amsdal_ml/ml_ingesting/embedders/__init__.py +4 -0
- amsdal_ml/ml_ingesting/embedders/embedder.py +12 -0
- amsdal_ml/ml_ingesting/embedders/openai_embedder.py +30 -0
- amsdal_ml/ml_ingesting/embedding_data.py +3 -0
- amsdal_ml/ml_ingesting/loaders/__init__.py +6 -0
- amsdal_ml/ml_ingesting/loaders/folder_loader.py +52 -0
- amsdal_ml/ml_ingesting/loaders/loader.py +28 -0
- amsdal_ml/ml_ingesting/loaders/pdf_loader.py +136 -0
- amsdal_ml/ml_ingesting/loaders/text_loader.py +44 -0
- amsdal_ml/ml_ingesting/model_ingester.py +278 -0
- amsdal_ml/ml_ingesting/pipeline.py +131 -0
- amsdal_ml/ml_ingesting/pipeline_interface.py +31 -0
- amsdal_ml/ml_ingesting/processors/__init__.py +4 -0
- amsdal_ml/ml_ingesting/processors/cleaner.py +14 -0
- amsdal_ml/ml_ingesting/processors/text_cleaner.py +42 -0
- amsdal_ml/ml_ingesting/splitters/__init__.py +4 -0
- amsdal_ml/ml_ingesting/splitters/splitter.py +15 -0
- amsdal_ml/ml_ingesting/splitters/token_splitter.py +85 -0
- amsdal_ml/ml_ingesting/stores/__init__.py +4 -0
- amsdal_ml/ml_ingesting/stores/embedding_data.py +63 -0
- amsdal_ml/ml_ingesting/stores/store.py +22 -0
- amsdal_ml/ml_ingesting/types.py +40 -0
- amsdal_ml/ml_models/models.py +96 -4
- amsdal_ml/ml_models/openai_model.py +430 -122
- amsdal_ml/ml_models/utils.py +7 -0
- amsdal_ml/ml_retrievers/__init__.py +17 -0
- amsdal_ml/ml_retrievers/adapters.py +93 -0
- amsdal_ml/ml_retrievers/default_retriever.py +11 -1
- amsdal_ml/ml_retrievers/openai_retriever.py +27 -7
- amsdal_ml/ml_retrievers/query_retriever.py +487 -0
- amsdal_ml/ml_retrievers/retriever.py +12 -0
- amsdal_ml/models/embedding_model.py +7 -7
- amsdal_ml/prompts/__init__.py +77 -0
- amsdal_ml/prompts/database_query_agent.prompt +14 -0
- amsdal_ml/prompts/functional_calling_agent_base.prompt +9 -0
- amsdal_ml/prompts/nl_query_filter.prompt +318 -0
- amsdal_ml/{agents/promts → prompts}/react_chat.prompt +17 -8
- amsdal_ml/utils/__init__.py +5 -0
- amsdal_ml/utils/query_utils.py +189 -0
- amsdal_ml-0.2.0.dist-info/METADATA +293 -0
- amsdal_ml-0.2.0.dist-info/RECORD +72 -0
- {amsdal_ml-0.1.3.dist-info → amsdal_ml-0.2.0.dist-info}/WHEEL +1 -1
- amsdal_ml/agents/promts/__init__.py +0 -58
- amsdal_ml-0.1.3.dist-info/METADATA +0 -69
- amsdal_ml-0.1.3.dist-info/RECORD +0 -39
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
"""Natural language query interface for Amsdal QuerySets."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from datetime import date
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
from typing import Generic
|
|
10
|
+
from typing import Literal
|
|
11
|
+
from typing import Optional
|
|
12
|
+
from typing import TypeVar
|
|
13
|
+
from typing import get_args
|
|
14
|
+
from typing import get_origin
|
|
15
|
+
|
|
16
|
+
from amsdal_models.classes.model import Model
|
|
17
|
+
from amsdal_models.querysets.base_queryset import QuerySetBase
|
|
18
|
+
from amsdal_utils.query.enums import Lookup
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
from amsdal_ml.ml_models.models import MLModel
|
|
22
|
+
from amsdal_ml.ml_models.utils import ResponseFormat
|
|
23
|
+
from amsdal_ml.ml_retrievers.adapters import DefaultRetrieverAdapter
|
|
24
|
+
from amsdal_ml.ml_retrievers.adapters import get_retriever_adapter
|
|
25
|
+
from amsdal_ml.ml_retrievers.retriever import Document
|
|
26
|
+
from amsdal_ml.ml_retrievers.retriever import Retriever
|
|
27
|
+
from amsdal_ml.prompts import get_prompt
|
|
28
|
+
from amsdal_ml.utils.query_utils import serialize_and_clean_record
|
|
29
|
+
|
|
30
|
+
T = TypeVar("T", bound=Model)
|
|
31
|
+
AmsdalQuerySet = QuerySetBase[T]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class FilterCondition(BaseModel):
    """Single filter condition for database query."""

    # NOTE: the docstring above is intentionally left untouched; pydantic copies
    # it into the JSON schema produced by ``model_json_schema()``.

    # Name of the model field the condition applies to.
    field: str
    # Lookup operator; NLQueryExecutor joins it to the field as ``<field>__<lookup>``.
    lookup: str
    # Comparison value: a scalar, a date/datetime, ``None``, or a list of scalars.
    value: str | int | float | bool | None | date | datetime | list[str | int | float | bool | date | datetime]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class FilterResponse(BaseModel):
    """Structured response for LLM."""

    # NOTE: the docstring above is intentionally left untouched; pydantic copies
    # it into the JSON schema produced by ``model_json_schema()``.

    # Conditions produced by the LLM for one natural-language query.
    filters: list[FilterCondition]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class NLQueryExecutor:
    """
    Natural language query interface for Amsdal QuerySets.

    Converts natural language queries into structured database filters using an LLM.

    The LLM answer is treated as untrusted input: a condition is applied only when
    its lookup is a known ``Lookup`` value and its field is part of the search
    schema built for this executor. Anything else is silently dropped.

    Supported Field Types:
    - Primitives: str, int, float, bool
    - Enums: Enum subclasses
    - Literals: Literal[...]
    - Unions: Union of supported types (not mixed with models/dicts/lists/tuples/sets)
    - Annotated: Annotated[supported_type, ...]
    - Dates: datetime, date

    Unsupported Field Types (will be skipped):
    - Models: BaseModel/Model subclasses
    - Dicts: dict[...]
    - Lists/Tuples/Sets: list/tuple/set[...]
    - Callables: Callable[...]
    - Unknown types: bytes, custom classes, etc.

    Example:
        >>> llm = OpenAIModel()
        >>> llm.setup()
        >>> base_qs = BaseRateInfo.objects.filter(accessibility__eq='public')
        >>> executor = NLQueryExecutor(
        ...     llm=llm,
        ...     queryset=base_qs,
        ...     base_filters={'status__eq': 'active'},
        ...     fields=['rate', 'premium_band', 'effective_date', 'rate_category'],
        ...     fields_info={'rate': 'Current rate value as decimal (5.15% = 0.0515)'}
        ... )
        >>> results = await executor.execute("show me all high band rates")
    """

    # Last fully rendered prompt sent to the LLM (kept for inspection/debugging).
    _full_prompt: Optional[str]
    # Maximum number of enum/literal options listed in a field type description.
    MAX_VALUES_DISPLAY_LIMIT = 100
    # Recursion guard for nested Annotated/Union annotations.
    MAX_SCHEMA_INSPECTION_DEPTH = 10

    def __init__(
        self,
        llm: MLModel,
        queryset: AmsdalQuerySet[T],
        base_filters: Optional[dict[str, Any]] = None,
        fields: Optional[list[str]] = None,
        fields_info: Optional[dict[str, str]] = None,
        prompt_path: Optional[str | Path] = None,
        llm_response_format: Optional[ResponseFormat] = None,
        adapter: Optional[DefaultRetrieverAdapter] = None,
    ) -> None:
        """
        Initialize NLQueryExecutor.

        Args:
            llm: ML model instance for query interpretation
            queryset: AMSDAL QuerySet
            base_filters: Optional dict of additional base filters to apply to all queries
                (e.g., {'accessibility__eq': 'public'})
            fields: Optional list of field names to include in search schema.
                If None, all model fields are used.
            fields_info: Optional dict mapping field names to descriptions.
                Takes priority over model field titles.
                Format: {'field_name': 'description text'}
            prompt_path: Optional path to custom prompt file. If None, uses default 'nl_query_filter'
                which is specifically designed for AMSDAL ORM with comprehensive coverage of all
                lookup operators, field type conversions, and detailed examples for
                optimal query interpretation.
            llm_response_format: Optional desired response format. Can be ResponseFormat enum or string
                ('JSON_SCHEMA', 'JSON_OBJECT', 'PLAIN_TEXT'). If None, will auto-detect best.
            adapter: Optional adapter instance. If None, will auto-detect based on LLM type.
        """
        self.llm = llm
        self.base_queryset = queryset
        self.model_class = queryset.entity
        self.base_filters = base_filters or {}
        self.fields = fields
        self.fields_info = fields_info or {}
        self.prompt_path = prompt_path
        self.adapter = adapter or get_retriever_adapter(llm)
        self.llm_response_format = self._select_response_format(llm_response_format)

        # The schema must exist before the system prompt, which embeds it.
        self.search_schema = self._build_search_schema()
        self.system_prompt = self._build_system_prompt()
        self._full_prompt = None

    @property
    def response_schema(self) -> dict[str, Any]:
        """
        Generates a JSON schema for the response.
        """
        base_schema = FilterResponse.model_json_schema()
        return self.adapter.get_response_schema(base_schema)

    async def execute(self, query: str, limit: int | None = 30) -> list[T]:
        """Run ``asearch`` and return at most ``limit`` records (all when ``limit`` is None)."""
        results: list[T] = await self.asearch(query)
        return results[:limit] if limit is not None else results

    async def asearch(self, query: str) -> list[T]:
        """Translate ``query`` into filters with the LLM and execute the queryset asynchronously."""
        response = await self.llm.ainvoke(self._render_prompt(query), **self._llm_kwargs())
        queryset = self._build_queryset(self._extract_json(response))
        return await queryset.aexecute()  # type: ignore[attr-defined]

    def search(self, query: str) -> list[T]:
        """Synchronous counterpart of ``asearch``."""
        response = self.llm.invoke(self._render_prompt(query), **self._llm_kwargs())
        queryset = self._build_queryset(self._extract_json(response))
        return queryset.execute()  # type: ignore[attr-defined]

    def _render_prompt(self, query: str) -> str:
        """Render the full prompt for ``query`` and remember it in ``_full_prompt``."""
        prompt = get_prompt("nl_query_filter", custom_path=self.prompt_path)
        full_prompt = prompt.render_text(
            schema=json.dumps(self.search_schema, indent=2), query=query
        )
        self._full_prompt = full_prompt
        return full_prompt

    def _llm_kwargs(self) -> dict[str, Any]:
        """Keyword arguments shared by ``llm.invoke`` and ``llm.ainvoke``."""
        wants_schema = self.llm_response_format == ResponseFormat.JSON_SCHEMA
        return {
            "response_format": self.llm_response_format,
            # A schema is only passed when the LLM is asked for schema-based output.
            "schema": self.response_schema if wants_schema else None,
        }

    def _extract_json(self, response: str) -> str:
        """Strip whitespace and, for plain-text responses, unwrap a fenced code block."""
        raw_json = response.strip()

        if self.llm_response_format == ResponseFormat.PLAIN_TEXT:
            if "```json" in raw_json:
                raw_json = raw_json.split("```json")[1].split("```")[0].strip()
            elif "```" in raw_json:
                raw_json = raw_json.split("```")[1].split("```")[0].strip()

        return raw_json

    def _build_queryset(self, raw_json: str) -> Any:
        """Apply the base filters plus the validated LLM conditions to the base queryset."""
        queryset = self.base_queryset.filter(**self.base_filters)

        conditions = self.adapter.parse_response(
            raw_json,
            is_schema_based=self.llm_response_format == ResponseFormat.JSON_SCHEMA,
        )
        if not conditions:
            return queryset

        llm_filters = self._select_filters(
            conditions,
            allowed_lookups={lookup.value for lookup in Lookup},
            allowed_fields={entry["name"] for entry in self.search_schema},
        )
        return queryset.filter(**llm_filters)

    @staticmethod
    def _select_filters(
        conditions: Any,
        allowed_lookups: set[Any],
        allowed_fields: set[str],
    ) -> dict[str, Any]:
        """
        Turn parsed conditions into ``{'<field>__<lookup>': value}`` filter kwargs.

        Conditions whose lookup or field is not allowed are skipped: the LLM output
        is not trusted to stay within the schema it was shown.
        """
        filters: dict[str, Any] = {}
        for cond in conditions:
            if cond.lookup not in allowed_lookups or cond.field not in allowed_fields:
                continue
            filters[f"{cond.field}__{cond.lookup}"] = cond.value
        return filters

    def _select_response_format(
        self,
        requested: Optional[ResponseFormat],
    ) -> ResponseFormat:
        """Return ``requested`` when given, else the richest format the LLM supports."""
        if requested:
            return requested

        supported = self.llm.supported_formats
        for candidate in (ResponseFormat.JSON_SCHEMA, ResponseFormat.JSON_OBJECT):
            if candidate in supported:
                return candidate
        return ResponseFormat.PLAIN_TEXT

    def _build_search_schema(self) -> list[dict[str, Any]]:
        """
        Describe every searchable field as ``{'name', 'type', 'description'}``.

        Model fields with unsupported types and names that are neither model
        fields nor queryset annotations are left out.
        """
        schema = []
        model_fields = self.model_class.model_fields
        fields_to_process = (
            self.fields
            if self.fields is not None
            else [*model_fields.keys(), *self.base_queryset._annotations.keys()]
        )

        for field_name in fields_to_process:
            if field_name in model_fields:
                field_obj = model_fields[field_name]
                # An explicit description wins over the field title, then the name.
                description = self.fields_info.get(
                    field_name, field_obj.title or field_name
                )
                field_type = self._get_field_type(field_obj.annotation)
                if field_type is None:
                    continue
            elif field_name in self.base_queryset._annotations:
                description = f"{field_name}: computed field based on internal fields"
                field_type = "string"
            else:
                continue

            schema.append(
                {"name": field_name, "type": field_type, "description": description}
            )

        return schema

    @staticmethod
    def _is_annotated(origin: Any) -> bool:
        """True when ``origin`` is the ``Annotated`` origin; safe for None and nameless origins."""
        return getattr(origin, "__name__", None) == "Annotated"

    @staticmethod
    def _is_union(origin: Any) -> bool:
        """True when ``origin`` looks like a union origin (its ``str()`` contains 'Union')."""
        return bool(origin) and "Union" in str(origin)

    def is_skip_type(self, typ: Any, depth: int = 0) -> bool:
        """Return True when ``typ`` is (or contains) a type that cannot be searched."""
        if depth > self.MAX_SCHEMA_INSPECTION_DEPTH:
            return True

        origin = get_origin(typ)

        if self._is_annotated(origin):
            return self.is_skip_type(get_args(typ)[0], depth + 1)

        if origin is dict:
            return True

        if origin in (list, tuple, set):
            return True

        if self._is_union(origin):
            # One unsupported member makes the whole union unsupported.
            return any(
                self.is_skip_type(arg, depth + 1)
                for arg in get_args(typ)
                if arg is not type(None)
            )

        try:
            if (
                isinstance(typ, type)
                and issubclass(typ, (BaseModel, Model))
                and not issubclass(typ, Enum)
            ):
                return True
        except TypeError:
            pass

        return False

    def _extract_enum_literal(self, typ: Any, depth: int = 0) -> list[Any] | None:
        """
        Return the allowed values of an Enum/Literal annotation, or None.

        For unions the values of all members are merged (duplicates removed,
        order kept); a union containing a plain primitive yields None because
        the field is then not restricted to a closed set.
        """
        if depth > self.MAX_SCHEMA_INSPECTION_DEPTH:
            return None

        origin = get_origin(typ)

        if self._is_annotated(origin):
            return self._extract_enum_literal(get_args(typ)[0], depth + 1)

        try:
            if isinstance(typ, type) and issubclass(typ, Enum):
                return [member.value for member in typ]
        except TypeError:
            pass

        if origin is Literal:
            return list(get_args(typ))

        if self._is_union(origin):
            all_values: list[Any] = []
            for arg in get_args(typ):
                if arg is type(None):
                    continue
                if arg in (str, int, float, bool):
                    return None
                values = self._extract_enum_literal(arg, depth + 1)
                if values:
                    all_values.extend(values)

            if not all_values:
                return None
            return list(dict.fromkeys(all_values))

        return None

    def _format_options(self, values: list[Any]) -> str:
        """Render option values as ``type(repr)`` items, truncated to the display limit."""
        if len(values) > self.MAX_VALUES_DISPLAY_LIMIT:
            values = [*values[: self.MAX_VALUES_DISPLAY_LIMIT], "..."]
        return ", ".join(f"{type(v).__name__}({v!r})" for v in values)

    def _get_field_type(self, annotation) -> str | None:
        """Return the schema type description for a field annotation, or None to skip it."""
        origin = get_origin(annotation)

        if origin is dict:
            return None

        # TODO: Enable list/set/tuple fields when AMSDAL adds JSONB array support
        skip_list_fields = True
        if origin in (list, tuple, set):
            if skip_list_fields:
                return None

            args = get_args(annotation)
            if not args:
                return "list of string"

            inner = args[0]
            if self.is_skip_type(inner):
                return None

            values = self._extract_enum_literal(inner)
            if values:
                return f"list of options: {self._format_options(values)}"

            inner_type = self._get_primitive_type(inner)
            if inner_type is None:
                return None
            return f"list of {inner_type}"

        if self.is_skip_type(annotation):
            return None

        values = self._extract_enum_literal(annotation)
        if values:
            return f"options: {self._format_options(values)}"

        return self._get_primitive_type(annotation)

    @staticmethod
    def _match_builtin(typ: Any) -> str | None:
        """Map the exact builtin scalar types to their schema type name."""
        if typ is int or typ is float:
            return "number"
        if typ is bool:
            return "boolean"
        if typ is str:
            return "string"
        return None

    def _get_primitive_type(self, typ: Any, depth: int = 0) -> str | None:
        """
        Return the schema type name for a scalar annotation, or None if unsupported.

        Unions are described by their first non-None member.
        """
        if depth > self.MAX_SCHEMA_INSPECTION_DEPTH:
            return "string"

        origin = get_origin(typ)

        if self._is_annotated(origin):
            return self._get_primitive_type(get_args(typ)[0], depth + 1)

        if self._is_union(origin):
            for arg in get_args(typ):
                if arg is type(None):
                    continue
                builtin = self._match_builtin(arg)
                if builtin is not None:
                    return builtin
                return self._get_primitive_type(arg, depth + 1) or "string"

        builtin = self._match_builtin(typ)
        if builtin is not None:
            return builtin

        # Not every annotation object carries __name__ (e.g. ForwardRef); treat
        # those as unsupported instead of raising AttributeError.
        raw_name = getattr(typ, "__name__", None)
        if not isinstance(raw_name, str):
            return None

        type_name = raw_name.lower()
        if type_name in ("int", "float"):
            return "number"
        if type_name == "bool":
            return "boolean"
        if type_name in ("str", "string"):
            return "string"
        if "datetime" in type_name:
            return "datetime (YYYY-MM-DD HH:MM:SS)"
        if "date" in type_name and "time" not in type_name:
            return "date (YYYY-MM-DD)"

        return None

    def _build_system_prompt(self) -> str:
        """Render the prompt with the schema filled in and ``{query}`` left as a placeholder."""
        schema_json = json.dumps(self.search_schema, indent=2, ensure_ascii=False)
        prompt = get_prompt('nl_query_filter', custom_path=self.prompt_path)
        return prompt.render_text(schema=schema_json, query='{query}')
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class NLQueryRetriever(Retriever, Generic[T]):
    """
    Adapter exposing an ``NLQueryExecutor`` through the ``Retriever`` interface.

    All constructor arguments are forwarded unchanged to ``NLQueryExecutor``.
    """

    def __init__(
        self,
        llm: MLModel,
        queryset: AmsdalQuerySet[T],
        base_filters: Optional[dict[str, Any]] = None,
        fields: Optional[list[str]] = None,
        fields_info: Optional[dict[str, str]] = None,
        prompt_path: Optional[str | Path] = None,
        llm_response_format: Optional[ResponseFormat] = None,
        adapter: Optional[DefaultRetrieverAdapter] = None,
    ) -> None:
        self.executor = NLQueryExecutor(
            llm=llm,
            queryset=queryset,
            base_filters=base_filters,
            fields=fields,
            fields_info=fields_info,
            prompt_path=prompt_path,
            llm_response_format=llm_response_format,
            adapter=adapter,
        )

    async def invoke(self, query: str, limit: int | None = 30) -> list[Document]:
        """
        Run the natural-language query and wrap each record in a ``Document``.

        ``page_content`` is the JSON dump of the cleaned record and ``metadata``
        is the record's ``model_dump()``.
        """
        records: list[T] = await self.executor.execute(query, limit)
        # Records are processed one at a time, in result order.
        return [
            Document(
                page_content=json.dumps(
                    await serialize_and_clean_record(record), ensure_ascii=False
                ),
                metadata=record.model_dump(),
            )
            for record in records
        ]
|
|
@@ -3,12 +3,23 @@ from __future__ import annotations
|
|
|
3
3
|
from abc import ABC
|
|
4
4
|
from abc import abstractmethod
|
|
5
5
|
from collections.abc import Iterable
|
|
6
|
+
from typing import Any
|
|
6
7
|
from typing import Optional
|
|
7
8
|
|
|
8
9
|
from pydantic import BaseModel
|
|
9
10
|
from pydantic import Field
|
|
10
11
|
|
|
11
12
|
|
|
13
|
+
class Document(BaseModel):
    # Result item returned by ``Retriever.invoke``.
    # (Documented with comments rather than a docstring so the pydantic JSON
    # schema of this model stays unchanged.)

    # Text payload of the document.
    page_content: str
    # Extra data attached to the document; defaults to a fresh empty dict.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Retriever(ABC):
    """Abstract interface for components that answer a text query with documents."""

    @abstractmethod
    async def invoke(self, query: str) -> list[Document]:
        """Return the ``Document`` objects found for ``query``."""
        ...
|
|
21
|
+
|
|
22
|
+
|
|
12
23
|
class RetrievalChunk(BaseModel):
|
|
13
24
|
object_class: str = Field(...)
|
|
14
25
|
object_id: str = Field(...)
|
|
@@ -16,6 +27,7 @@ class RetrievalChunk(BaseModel):
|
|
|
16
27
|
raw_text: str = Field(...)
|
|
17
28
|
distance: float = Field(...)
|
|
18
29
|
tags: list[str] = Field(default_factory=list)
|
|
30
|
+
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
19
31
|
|
|
20
32
|
|
|
21
33
|
class MLRetriever(ABC):
|
|
@@ -8,14 +8,14 @@ from pydantic import Field
|
|
|
8
8
|
|
|
9
9
|
class EmbeddingModel(Model):
|
|
10
10
|
__module_type__ = ModuleType.CONTRIB
|
|
11
|
-
__table_name__ =
|
|
11
|
+
__table_name__ = 'embedding_model'
|
|
12
12
|
|
|
13
|
-
data_object_class: str = Field(..., title=
|
|
14
|
-
data_object_id: str = Field(..., title=
|
|
13
|
+
data_object_class: str = Field(..., title='Linked object class')
|
|
14
|
+
data_object_id: str = Field(..., title='Linked object ID')
|
|
15
15
|
|
|
16
|
-
chunk_index: int = Field(..., title=
|
|
17
|
-
raw_text: str = Field(..., title=
|
|
16
|
+
chunk_index: int = Field(..., title='Chunk index')
|
|
17
|
+
raw_text: str = Field(..., title='Raw text used for embedding')
|
|
18
18
|
|
|
19
19
|
embedding: VectorField(1536) # type: ignore[valid-type]
|
|
20
|
-
tags: list[str] = Field(default_factory=list, title=
|
|
21
|
-
ml_metadata: Any = Field(default=None, title=
|
|
20
|
+
tags: list[str] = Field(default_factory=list, title='Embedding tags')
|
|
21
|
+
ml_metadata: Any = Field(default=None, title='ML metadata')
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
from typing import Final
|
|
7
|
+
|
|
8
|
+
# avoid "magic numbers"
|
|
9
|
+
_PARTS_EXPECTED: Final[int] = 2
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _SafeDict(dict[str, Any]):
    """Mapping for ``str.format_map`` that keeps unknown placeholders in the output."""

    def __missing__(self, key: str) -> str:
        # Re-emit the placeholder itself so it can still be filled in later.
        return f'{{{key}}}'


@dataclass
class Prompt:
    """A named prompt template made of a system part and a user part."""

    name: str
    system: str
    user: str

    def _render_parts(self, values: dict[str, Any]) -> tuple[str, str]:
        """Format the system template, then the user template, with ``values``."""
        safe_values = _SafeDict(values)
        rendered_system = self.system.format_map(safe_values)
        rendered_user = self.user.format_map(safe_values)
        return rendered_system, rendered_user

    def render_text(self, **kwargs: Any) -> str:
        """Return both rendered parts as one string separated by a blank line."""
        return '\n\n'.join(self._render_parts(kwargs)).strip()

    def render_messages(self, **kwargs: Any) -> list[dict[str, str]]:
        """Return the rendered parts as ``system`` and ``user`` chat messages."""
        rendered_system, rendered_user = self._render_parts(kwargs)
        return [
            {'role': 'system', 'content': rendered_system},
            {'role': 'user', 'content': rendered_user},
        ]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
_prompt_cache: dict[str, Prompt] = {}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _load_file(name: str, custom_path: Path | None = None) -> Prompt:
    """
    Read a prompt file and split it into system and user templates.

    The file is ``custom_path`` when given, otherwise ``<name>.prompt`` next to
    this module. Text before the first ``---`` line is the system template and
    the rest is the user template; without a separator the whole file is the
    system template and the user template is ``{input}``.

    Raises:
        FileNotFoundError: if the prompt file does not exist.
    """
    path = custom_path if custom_path else Path(__file__).resolve().parent / f'{name}.prompt'

    if not path.exists():
        msg = f"Prompt '{name}' not found at {path}"
        raise FileNotFoundError(msg)

    # partition() cuts at the first separator only, like split(sep, 1).
    system, separator, user = path.read_text(encoding='utf-8').partition('\n---\n')
    if not separator:
        user = '{input}'

    return Prompt(name=name, system=system.strip(), user=user.strip())
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_prompt(name: str, custom_path: str | Path | None = None) -> Prompt:
    """
    Load a prompt template by name or from a custom path.

    Loaded prompts are cached for the lifetime of the process. Built-in prompts
    are cached under their name; custom prompts are cached under their resolved
    absolute path, so a custom file can never shadow (or be shadowed by) a
    built-in prompt with the same text key, and the same relative path used
    from different working directories is not mixed up.

    Args:
        name: Name of the prompt (without .prompt extension) or identifier
        custom_path: Optional path to custom prompt file. If provided, loads from this path instead.

    Returns:
        Prompt object with system and user templates

    Raises:
        FileNotFoundError: if the prompt file does not exist.
    """
    if custom_path:
        path_obj: Path | None = Path(custom_path)
        cache_key = str(path_obj.resolve())
    else:
        path_obj = None
        cache_key = name

    cached = _prompt_cache.get(cache_key)
    if cached is None:
        cached = _prompt_cache[cache_key] = _load_file(name, path_obj)

    return cached
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
You are a knowledgeable assistant specialized in database queries. Use only the provided context (including any user-attached files) to answer accurately and concisely. If relevant information such as product names, crediting strategies, interest rates, or surrender charge schedules is available, include it. If the context is insufficient, clearly say you don't have enough data and avoid assumptions.
|
|
2
|
+
|
|
3
|
+
CRITICAL TOOL USAGE:
|
|
4
|
+
- For ANY question that requires retrieving specific data from the database, ALWAYS use the 'search_database' tool first. Pass the user's EXACT question as the query parameter without rephrasing or modifying it.
|
|
5
|
+
- If the user asks to display data in a table format, use the 'render_markdown_table' tool after retrieving the data with 'search_database'.
|
|
6
|
+
|
|
7
|
+
IMPORTANT OUTPUT INSTRUCTION:
|
|
8
|
+
- When using the 'render_markdown_table' tool, your Final Answer must contain ONLY the raw Markdown table string.
|
|
9
|
+
- Do NOT add introductory text (e.g., 'Here are the results').
|
|
10
|
+
- Do NOT add concluding remarks.
|
|
11
|
+
- Do NOT wrap the table in markdown code blocks (like ```markdown ... ```). Just output the table directly.
|
|
12
|
+
- If you use 'search_database' without 'render_markdown_table', provide a natural language summary of the results.
|
|
13
|
+
---
|
|
14
|
+
{input}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
You are an assistant that can call tools to help answer questions.
|
|
2
|
+
|
|
3
|
+
CRITICAL TOOL USAGE:
|
|
4
|
+
- Use available tools when necessary to provide accurate responses.
|
|
5
|
+
|
|
6
|
+
IMPORTANT OUTPUT INSTRUCTION:
|
|
7
|
+
- Provide clear and concise answers based on tool results.
|
|
8
|
+
---
|
|
9
|
+
{input}
|