kiarina-lib-redisearch 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiarina/lib/redisearch/__init__.py +35 -0
- kiarina/lib/redisearch/_async/__init__.py +0 -0
- kiarina/lib/redisearch/_async/client.py +181 -0
- kiarina/lib/redisearch/_async/registry.py +16 -0
- kiarina/lib/redisearch/_core/__init__.py +0 -0
- kiarina/lib/redisearch/_core/context.py +69 -0
- kiarina/lib/redisearch/_core/operations/__init__.py +0 -0
- kiarina/lib/redisearch/_core/operations/count.py +55 -0
- kiarina/lib/redisearch/_core/operations/create_index.py +52 -0
- kiarina/lib/redisearch/_core/operations/delete.py +43 -0
- kiarina/lib/redisearch/_core/operations/drop_index.py +59 -0
- kiarina/lib/redisearch/_core/operations/exists_index.py +56 -0
- kiarina/lib/redisearch/_core/operations/find.py +105 -0
- kiarina/lib/redisearch/_core/operations/get.py +61 -0
- kiarina/lib/redisearch/_core/operations/get_info.py +155 -0
- kiarina/lib/redisearch/_core/operations/get_key.py +8 -0
- kiarina/lib/redisearch/_core/operations/migrate_index.py +160 -0
- kiarina/lib/redisearch/_core/operations/reset_index.py +60 -0
- kiarina/lib/redisearch/_core/operations/search.py +111 -0
- kiarina/lib/redisearch/_core/operations/set.py +65 -0
- kiarina/lib/redisearch/_core/utils/__init__.py +0 -0
- kiarina/lib/redisearch/_core/utils/calc_score.py +35 -0
- kiarina/lib/redisearch/_core/utils/marshal_mappings.py +57 -0
- kiarina/lib/redisearch/_core/utils/parse_search_result.py +57 -0
- kiarina/lib/redisearch/_core/utils/unmarshal_mappings.py +57 -0
- kiarina/lib/redisearch/_core/views/__init__.py +0 -0
- kiarina/lib/redisearch/_core/views/document.py +25 -0
- kiarina/lib/redisearch/_core/views/info_result.py +24 -0
- kiarina/lib/redisearch/_core/views/search_result.py +31 -0
- kiarina/lib/redisearch/_sync/__init__.py +0 -0
- kiarina/lib/redisearch/_sync/client.py +179 -0
- kiarina/lib/redisearch/_sync/registry.py +16 -0
- kiarina/lib/redisearch/asyncio.py +33 -0
- kiarina/lib/redisearch/filter/__init__.py +61 -0
- kiarina/lib/redisearch/filter/_decorators.py +28 -0
- kiarina/lib/redisearch/filter/_enums.py +28 -0
- kiarina/lib/redisearch/filter/_field/__init__.py +5 -0
- kiarina/lib/redisearch/filter/_field/base.py +67 -0
- kiarina/lib/redisearch/filter/_field/numeric.py +178 -0
- kiarina/lib/redisearch/filter/_field/tag.py +142 -0
- kiarina/lib/redisearch/filter/_field/text.py +111 -0
- kiarina/lib/redisearch/filter/_model.py +93 -0
- kiarina/lib/redisearch/filter/_registry.py +153 -0
- kiarina/lib/redisearch/filter/_types.py +32 -0
- kiarina/lib/redisearch/filter/_utils.py +18 -0
- kiarina/lib/redisearch/py.typed +0 -0
- kiarina/lib/redisearch/schema/__init__.py +25 -0
- kiarina/lib/redisearch/schema/_field/__init__.py +0 -0
- kiarina/lib/redisearch/schema/_field/base.py +20 -0
- kiarina/lib/redisearch/schema/_field/numeric.py +33 -0
- kiarina/lib/redisearch/schema/_field/tag.py +46 -0
- kiarina/lib/redisearch/schema/_field/text.py +44 -0
- kiarina/lib/redisearch/schema/_field/vector/__init__.py +0 -0
- kiarina/lib/redisearch/schema/_field/vector/base.py +61 -0
- kiarina/lib/redisearch/schema/_field/vector/flat.py +40 -0
- kiarina/lib/redisearch/schema/_field/vector/hnsw.py +53 -0
- kiarina/lib/redisearch/schema/_model.py +98 -0
- kiarina/lib/redisearch/schema/_types.py +16 -0
- kiarina/lib/redisearch/settings.py +47 -0
- kiarina_lib_redisearch-1.0.0.dist-info/METADATA +886 -0
- kiarina_lib_redisearch-1.0.0.dist-info/RECORD +62 -0
- kiarina_lib_redisearch-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,61 @@
|
|
1
|
+
from typing import Awaitable, Literal, overload
|
2
|
+
|
3
|
+
from ..context import RedisearchContext
|
4
|
+
from ..utils.unmarshal_mappings import unmarshal_mappings
|
5
|
+
from ..views.document import Document
|
6
|
+
from .get_key import get_key
|
7
|
+
|
8
|
+
|
9
|
+
@overload
def get(
    mode: Literal["sync"],
    ctx: RedisearchContext,
    id: str,
) -> Document | None: ...


@overload
def get(
    mode: Literal["async"],
    ctx: RedisearchContext,
    id: str,
) -> Awaitable[Document | None]: ...


def get(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
    id: str,
) -> Document | None | Awaitable[Document | None]:
    """
    Get a document from the index.

    Returns None (or an Awaitable resolving to None in "async" mode) when
    the key does not exist in Redis.
    """
    key = get_key(ctx, id)

    def _to_document(raw: dict[bytes, bytes]) -> Document | None:
        # HGETALL returns an empty mapping for a missing key.
        if not raw:
            return None

        return Document(
            key=key,
            id=id,
            mapping=unmarshal_mappings(schema=ctx.schema, mapping=raw),
        )

    if mode == "sync":
        raw = ctx.redis.hgetall(key)
        assert isinstance(raw, dict)
        return _to_document(raw)

    async def _run() -> Document | None:
        coro = ctx.redis_async.hgetall(key)
        assert not isinstance(coro, dict)
        return _to_document(await coro)

    return _run()
|
@@ -0,0 +1,155 @@
|
|
1
|
+
from typing import Any, Awaitable, Literal, overload
|
2
|
+
|
3
|
+
from ...schema import RedisearchSchema
|
4
|
+
from ..context import RedisearchContext
|
5
|
+
from ..views.info_result import InfoResult
|
6
|
+
|
7
|
+
|
8
|
+
@overload
def get_info(
    mode: Literal["sync"],
    ctx: RedisearchContext,
) -> InfoResult: ...


@overload
def get_info(
    mode: Literal["async"],
    ctx: RedisearchContext,
) -> Awaitable[InfoResult]: ...


def get_info(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
) -> InfoResult | Awaitable[InfoResult]:
    """
    Get index information using FT.INFO command.
    """

    def _build(raw: dict[str, Any]) -> InfoResult:
        # Missing counters default to 0 so a partial FT.INFO reply still parses.
        return InfoResult(
            index_name=str(raw.get("index_name", "")),
            num_docs=int(raw.get("num_docs", 0)),
            num_terms=int(raw.get("num_terms", 0)),
            num_records=int(raw.get("num_records", 0)),
            index_schema=_parse_schema(ctx.schema, raw),
        )

    if mode == "sync":
        raw = ctx.redis.ft(index_name=ctx.settings.index_name).info()  # type: ignore[no-untyped-call]
        assert isinstance(raw, dict)
        return _build(raw)

    async def _run() -> InfoResult:
        raw = await ctx.redis_async.ft(index_name=ctx.settings.index_name).info()  # type: ignore[no-untyped-call]
        assert isinstance(raw, dict)
        return _build(raw)

    return _run()
|
53
|
+
|
54
|
+
|
55
|
+
def _parse_schema(schema: RedisearchSchema, result: dict[str, Any]) -> RedisearchSchema:
    """
    Parse the schema information from the FT.INFO results
    """
    if "attributes" not in result:
        raise ValueError("The FT.INFO results do not contain attributes.")

    # Each attribute is a flat key/value list; normalize it to a dict and
    # then to a field dict understood by RedisearchSchema.
    fields: list[dict[str, Any]] = [
        _parse_field(_parse_attribute(attr)) for attr in result["attributes"]
    ]

    return RedisearchSchema.from_field_dicts(fields)
|
70
|
+
|
71
|
+
|
72
|
+
def _parse_attribute(attr: Any) -> dict[str, Any]:
    """Turn a flat FT.INFO attribute list ([key, value, key, value, ...]) into a dict.

    Bytes keys/values are decoded as UTF-8; list values are decoded element-wise.
    A trailing key without a value is discarded.
    """

    def _decode(value: Any) -> Any:
        return value.decode("utf-8") if isinstance(value, bytes) else value

    parsed: dict[str, Any] = {}

    for pos in range(0, len(attr), 2):
        key = _decode(attr[pos])

        if pos + 1 >= len(attr):
            # Odd-length list: the last key has no paired value.
            break

        raw_value = attr[pos + 1]

        if isinstance(raw_value, list):
            parsed[key] = [_decode(v) for v in raw_value]
        else:
            parsed[key] = _decode(raw_value)

    return parsed
|
92
|
+
|
93
|
+
|
94
|
+
def _parse_field(attr_dict: dict[str, Any]) -> dict[str, Any]:
    """Dispatch a parsed FT.INFO attribute dict to the parser for its field type."""
    field_type = _get_field_type(attr_dict)

    field_dict: dict[str, Any] = {"name": str(attr_dict.get("identifier"))}

    parsers = {
        "tag": _parse_tag_field,
        "numeric": _parse_numeric_field,
        "text": _parse_text_field,
        "vector": _parse_vector_field,
    }

    parser = parsers.get(field_type)

    if parser is None:
        raise ValueError(f"Unknown field type: {field_type}")

    return parser(field_dict, attr_dict)
|
110
|
+
|
111
|
+
|
112
|
+
def _parse_tag_field(
    field_dict: dict[str, Any], attr_dict: dict[str, Any]
) -> dict[str, Any]:
    """Fill TAG-specific options from an FT.INFO attribute dict and return field_dict."""
    # Flag options are encoded by mere key presence in the FT.INFO reply.
    field_dict.update(
        separator=str(attr_dict.get("SEPARATOR", ",")),
        case_sensitive="CASE_SENSITIVE" in attr_dict,
        no_index="NO_INDEX" in attr_dict,
        sortable="SORTABLE" in attr_dict,
    )
    return field_dict
|
120
|
+
|
121
|
+
|
122
|
+
def _parse_numeric_field(
    field_dict: dict[str, Any], attr_dict: dict[str, Any]
) -> dict[str, Any]:
    """Fill NUMERIC-specific options from an FT.INFO attribute dict and return field_dict."""
    # Flag options are encoded by mere key presence in the FT.INFO reply.
    field_dict.update(
        no_index="NO_INDEX" in attr_dict,
        sortable="SORTABLE" in attr_dict,
    )
    return field_dict
|
128
|
+
|
129
|
+
|
130
|
+
def _parse_text_field(
    field_dict: dict[str, Any], attr_dict: dict[str, Any]
) -> dict[str, Any]:
    """Fill TEXT-specific options from an FT.INFO attribute dict and return field_dict."""
    # WEIGHT defaults to 1.0; the remaining options are presence flags.
    field_dict.update(
        weight=float(attr_dict.get("WEIGHT", 1.0)),
        no_stem="NO_STEM" in attr_dict,
        withsuffixtrie="WITHSUFFIX" in attr_dict,
        no_index="NO_INDEX" in attr_dict,
        sortable="SORTABLE" in attr_dict,
    )
    return field_dict
|
139
|
+
|
140
|
+
|
141
|
+
def _parse_vector_field(
    field_dict: dict[str, Any], attr_dict: dict[str, Any]
) -> dict[str, Any]:
    """Fill VECTOR-specific options from an FT.INFO attribute dict and return field_dict."""
    field_dict.update(
        dims=int(attr_dict.get("dim", 0)),
        algorithm=str(attr_dict.get("algorithm", "")),
        datatype=str(attr_dict.get("data_type", "")),
        distance_metric=str(attr_dict.get("distance_metric", "")),
    )
    return field_dict
|
149
|
+
|
150
|
+
|
151
|
+
def _get_field_type(attr_dict: dict[str, Any]) -> str:
    """Return the lower-cased field type, raising ValueError when FT.INFO omitted it."""
    if "type" in attr_dict:
        return str(attr_dict["type"]).lower()

    raise ValueError("The FT.INFO results do not include the field type.")
|
@@ -0,0 +1,160 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, Awaitable, Literal, overload
|
3
|
+
|
4
|
+
from ...schema import RedisearchSchema
|
5
|
+
from ..context import RedisearchContext
|
6
|
+
from .create_index import create_index
|
7
|
+
from .drop_index import drop_index
|
8
|
+
from .exists_index import exists_index
|
9
|
+
from .get_info import get_info
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
@overload
def migrate_index(
    mode: Literal["sync"],
    ctx: RedisearchContext,
) -> bool: ...


@overload
def migrate_index(
    mode: Literal["async"],
    ctx: RedisearchContext,
) -> Awaitable[bool]: ...


def migrate_index(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
) -> bool | Awaitable[bool]:
    """
    Migrate the search index to match the configured schema.

    - Creates the index when it does not exist.
    - When it exists, compares the live schema (via FT.INFO) against
      ``ctx.schema``; if they differ, drops the index definition (documents
      are kept) and re-creates it so the data is re-indexed.
    - Does nothing when the schemas already match.

    Returns:
        True when an index was created or re-created, False when no
        migration was needed (an Awaitable of that bool in "async" mode).
    """

    def _log_create_new_index() -> None:
        # Fixed message typo: "Createing" -> "Creating".
        logger.info("Creating new index '%s'", ctx.settings.index_name)

    def _log_no_schema_changes() -> None:
        logger.info("No schema changes detected, migration not needed.")

    def _log_migration_needed(diffs: dict[str, tuple[Any, Any]]) -> None:
        logger.info("Schema changes detected, migration needed:")

        for path, (old, new) in diffs.items():
            logger.info(" - %s: %r -> %r", path, old, new)

    def _log_delete_index() -> None:
        logger.info(
            "Deleting existing index '%s', data will be re-indexed",
            ctx.settings.index_name,
        )

    def _sync() -> bool:
        if not exists_index(mode="sync", ctx=ctx):
            _log_create_new_index()
            create_index(mode="sync", ctx=ctx)
            return True

        info_result = get_info(mode="sync", ctx=ctx)
        diffs = _check_schema_changes(current=info_result.index_schema, new=ctx.schema)

        if not diffs:
            _log_no_schema_changes()
            return False

        _log_migration_needed(diffs)

        _log_delete_index()
        # delete_documents=False: only the index definition is dropped, so
        # existing documents are re-indexed by the new index.
        drop_index(mode="sync", ctx=ctx, delete_documents=False)

        _log_create_new_index()
        create_index(mode="sync", ctx=ctx)
        return True

    async def _async() -> bool:
        if not await exists_index(mode="async", ctx=ctx):
            _log_create_new_index()
            await create_index(mode="async", ctx=ctx)
            return True

        info_result = await get_info(mode="async", ctx=ctx)
        diffs = _check_schema_changes(current=info_result.index_schema, new=ctx.schema)

        if not diffs:
            _log_no_schema_changes()
            return False

        _log_migration_needed(diffs)

        _log_delete_index()
        await drop_index(mode="async", ctx=ctx, delete_documents=False)

        _log_create_new_index()
        await create_index(mode="async", ctx=ctx)
        return True

    if mode == "sync":
        return _sync()
    else:
        return _async()
|
102
|
+
|
103
|
+
|
104
|
+
def _check_schema_changes(
    current: RedisearchSchema,
    new: RedisearchSchema,
) -> dict[str, tuple[Any, Any]]:
    """Return a path -> (old, new) mapping of schema differences; empty when equal."""
    # Fast path: identical schemas need no field-by-field diff.
    if current == new:
        return {}

    return _diff_dict(current.model_dump(), new.model_dump())
|
115
|
+
|
116
|
+
|
117
|
+
def _diff_dict(
    d1: dict[str, Any], d2: dict[str, Any], prefix: str = ""
) -> dict[str, tuple[Any, Any]]:
    """Recursively diff two dicts.

    Returns a mapping of dotted (and "[i]"-indexed) paths to (old, new)
    value pairs. Keys missing from one side appear with None on that side.
    """
    diffs: dict[str, tuple[Any, Any]] = {}

    for key in set(d1) | set(d2):
        left, right = d1.get(key), d2.get(key)
        path = f"{prefix}.{key}" if prefix else key

        # Nested dict: recurse and merge child diffs.
        if isinstance(left, dict) and isinstance(right, dict):
            diffs.update(_diff_dict(left, right, prefix=path))

        # Nested list: compare element-wise up to the longer length.
        elif isinstance(left, list) and isinstance(right, list):
            for idx in range(max(len(left), len(right))):
                item_path = f"{path}[{idx}]"

                if idx >= len(left) or idx >= len(right):
                    # One side is shorter: report the missing element as None.
                    diffs[item_path] = (
                        left[idx] if idx < len(left) else None,
                        right[idx] if idx < len(right) else None,
                    )
                    continue

                item_left, item_right = left[idx], right[idx]

                if isinstance(item_left, dict) and isinstance(item_right, dict):
                    diffs.update(_diff_dict(item_left, item_right, prefix=item_path))
                elif item_left != item_right:
                    diffs[item_path] = (item_left, item_right)

        # Different scalar (or mismatched-type) values.
        elif left != right:
            diffs[path] = (left, right)

    return diffs
|
@@ -0,0 +1,60 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Awaitable, Literal, overload
|
3
|
+
|
4
|
+
from ..context import RedisearchContext
|
5
|
+
from .create_index import create_index
|
6
|
+
from .drop_index import drop_index
|
7
|
+
from .exists_index import exists_index
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
|
11
|
+
|
12
|
+
@overload
def reset_index(
    mode: Literal["sync"],
    ctx: RedisearchContext,
) -> None: ...


@overload
def reset_index(
    mode: Literal["async"],
    ctx: RedisearchContext,
) -> Awaitable[None]: ...


def reset_index(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
) -> None | Awaitable[None]:
    """
    Reset the search index.

    Drops the existing index together with its documents (when present) and
    creates a fresh, empty index.
    """

    if mode == "sync":
        if exists_index(mode="sync", ctx=ctx):
            logger.info("Deleting existing index '%s'", ctx.settings.index_name)
            # delete_documents=True: a reset wipes the data, not just the index.
            drop_index(mode="sync", ctx=ctx, delete_documents=True)

        logger.info("Creating new index '%s'", ctx.settings.index_name)
        create_index(mode="sync", ctx=ctx)
        return None

    async def _run() -> None:
        if await exists_index(mode="async", ctx=ctx):
            logger.info("Deleting existing index '%s'", ctx.settings.index_name)
            await drop_index(mode="async", ctx=ctx, delete_documents=True)

        logger.info("Creating new index '%s'", ctx.settings.index_name)
        await create_index(mode="async", ctx=ctx)

    return _run()
|
@@ -0,0 +1,111 @@
|
|
1
|
+
from typing import Any, Awaitable, Literal, overload
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from redis.commands.search.query import Query
|
5
|
+
from redis.commands.search.result import Result
|
6
|
+
|
7
|
+
from ...filter import (
|
8
|
+
RedisearchFilter,
|
9
|
+
RedisearchFilterConditions,
|
10
|
+
create_redisearch_filter,
|
11
|
+
)
|
12
|
+
from ..context import RedisearchContext
|
13
|
+
from ..utils.parse_search_result import parse_search_result
|
14
|
+
from ..views.search_result import SearchResult
|
15
|
+
from .count import count
|
16
|
+
|
17
|
+
|
18
|
+
@overload
def search(
    mode: Literal["sync"],
    ctx: RedisearchContext,
    vector: list[float],
    filter: RedisearchFilter | RedisearchFilterConditions | None = None,
    offset: int | None = None,
    limit: int | None = None,
    return_fields: list[str] | None = None,
) -> SearchResult: ...


@overload
def search(
    mode: Literal["async"],
    ctx: RedisearchContext,
    vector: list[float],
    filter: RedisearchFilter | RedisearchFilterConditions | None = None,
    offset: int | None = None,
    limit: int | None = None,
    return_fields: list[str] | None = None,
) -> Awaitable[SearchResult]: ...


def search(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
    vector: list[float],
    filter: RedisearchFilter | RedisearchFilterConditions | None = None,
    offset: int | None = None,
    limit: int | None = None,
    return_fields: list[str] | None = None,
) -> SearchResult | Awaitable[SearchResult]:
    """
    Search documents using vector similarity search.

    Args:
        mode: "sync" or "async" execution.
        ctx: Redisearch context holding clients, settings and schema.
        vector: Query vector; cast to the schema vector field's dtype.
        filter: Optional pre-filter (filter object or conditions).
        offset: Paging offset; defaults to 0.
        limit: Maximum number of results. When None, the total count of
            matching documents is used so every match is returned.
        return_fields: Fields to include in each hit; "distance" is always
            added. The caller's list is not modified.
    """
    # filter_query
    if filter is not None:
        filter = create_redisearch_filter(filter=filter, schema=ctx.schema)

    filter_query = "*" if filter is None else str(filter)

    # vector_field_name
    vector_field_name = ctx.schema.vector_field.name

    # return_fields: copy before appending "distance" so the caller's list
    # is never mutated (the original code appended to it in place).
    return_fields = list(return_fields) if return_fields else []

    if "distance" not in return_fields:
        return_fields.append("distance")

    # params
    params: dict[str, str | int | float | bytes] = {
        "vector": np.array(vector).astype(ctx.schema.vector_field.dtype).tobytes()
    }

    def _build_query(limit: int) -> Query:
        # KNN clause aliases the vector distance as "distance" for sorting
        # and for inclusion in return_fields.
        query = Query(
            f"({filter_query})=>[KNN {limit} @{vector_field_name} $vector AS distance]"
        )

        if return_fields:
            query = query.return_fields(*return_fields)
        else:
            query = query.no_content()

        query = query.sort_by("distance")
        query = query.paging(offset or 0, limit)

        return query

    def _parse_search_result(result: Any) -> SearchResult:
        assert isinstance(result, Result)
        return parse_search_result(
            key_prefix=ctx.settings.key_prefix,
            schema=ctx.schema,
            return_fields=return_fields,
            result=result,
        )

    def _sync() -> SearchResult:
        # When limit is None, use the total match count so all hits come back.
        query = _build_query(limit or count("sync", ctx, filter).total)
        result = ctx.redis.ft(ctx.settings.index_name).search(query, params)
        return _parse_search_result(result)

    async def _async() -> SearchResult:
        query = _build_query(limit or (await count("async", ctx, filter)).total)
        result = await ctx.redis_async.ft(ctx.settings.index_name).search(query, params)  # type: ignore
        return _parse_search_result(result)

    if mode == "sync":
        return _sync()
    else:
        return _async()
|
@@ -0,0 +1,65 @@
|
|
1
|
+
from typing import Any, Awaitable, Literal, overload
|
2
|
+
|
3
|
+
from ..context import RedisearchContext
|
4
|
+
from ..utils.marshal_mappings import marshal_mappings
|
5
|
+
from .get_key import get_key
|
6
|
+
|
7
|
+
|
8
|
+
@overload
def set(
    mode: Literal["sync"],
    ctx: RedisearchContext,
    mapping: dict[str, Any],
    *,
    id: str | None = None,
) -> None: ...


@overload
def set(
    mode: Literal["async"],
    ctx: RedisearchContext,
    mapping: dict[str, Any],
    *,
    id: str | None = None,
) -> Awaitable[None]: ...


def set(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
    mapping: dict[str, Any],
    *,
    id: str | None = None,
) -> None | Awaitable[None]:
    """
    Set a document in the index.

    Fields not present in the schema are saved as they are.
    Fields present in the schema are converted to the appropriate type and stored.
    """
    if id is None:
        # Fall back to the "id" field carried inside the mapping itself.
        if "id" not in mapping:
            raise ValueError(
                'Either "id" parameter or "id" field in mapping must be provided.'
            )

        id = str(mapping.get("id"))

    marshaled = marshal_mappings(schema=ctx.schema, mapping=mapping)
    key = get_key(ctx, id)

    if mode == "sync":
        ctx.redis.hset(key, mapping=marshaled)
        return None

    async def _run() -> None:
        coro = ctx.redis_async.hset(key, mapping=marshaled)
        assert not isinstance(coro, int)
        await coro

    return _run()
|
File without changes
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import math
|
2
|
+
from typing import Literal
|
3
|
+
|
4
|
+
|
5
|
+
def calc_score(
    distance: float,
    *,
    datatype: Literal["FLOAT32", "FLOAT64"],
    distance_metric: Literal["COSINE", "IP", "L2"],
) -> float:
    """
    Calculate relevance score from distance.

    Args:
        distance: Raw distance reported by the vector search.
        datatype: Vector datatype; controls the rounding precision applied
            to the distance (4 decimals for FLOAT32, 7 for FLOAT64).
        distance_metric: Metric the distance was computed with.

    Raises:
        ValueError: If the distance metric is not one of COSINE, IP, L2.
    """
    # Round away float noise; FLOAT32 carries fewer meaningful digits.
    precision = 4 if datatype == "FLOAT32" else 7
    distance = round(distance, precision)

    if distance_metric == "COSINE":
        # Normalise the cosine distance to a score within the range [0, 1]
        return 1.0 - distance

    if distance_metric == "IP":
        # Normalise the inner product distance to a score within the range [0, 1]
        return 1.0 - distance if distance > 0 else -1.0 * distance

    if distance_metric == "L2":
        # Convert the Euclidean distance to a similarity score within the range [0, 1]
        return 1.0 - distance / math.sqrt(2)

    raise ValueError(f"Unsupported distance metric: {distance_metric}")
|