kiarina-lib-redisearch 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. kiarina/lib/redisearch/__init__.py +35 -0
  2. kiarina/lib/redisearch/_async/__init__.py +0 -0
  3. kiarina/lib/redisearch/_async/client.py +181 -0
  4. kiarina/lib/redisearch/_async/registry.py +16 -0
  5. kiarina/lib/redisearch/_core/__init__.py +0 -0
  6. kiarina/lib/redisearch/_core/context.py +69 -0
  7. kiarina/lib/redisearch/_core/operations/__init__.py +0 -0
  8. kiarina/lib/redisearch/_core/operations/count.py +55 -0
  9. kiarina/lib/redisearch/_core/operations/create_index.py +52 -0
  10. kiarina/lib/redisearch/_core/operations/delete.py +43 -0
  11. kiarina/lib/redisearch/_core/operations/drop_index.py +59 -0
  12. kiarina/lib/redisearch/_core/operations/exists_index.py +56 -0
  13. kiarina/lib/redisearch/_core/operations/find.py +105 -0
  14. kiarina/lib/redisearch/_core/operations/get.py +61 -0
  15. kiarina/lib/redisearch/_core/operations/get_info.py +155 -0
  16. kiarina/lib/redisearch/_core/operations/get_key.py +8 -0
  17. kiarina/lib/redisearch/_core/operations/migrate_index.py +160 -0
  18. kiarina/lib/redisearch/_core/operations/reset_index.py +60 -0
  19. kiarina/lib/redisearch/_core/operations/search.py +111 -0
  20. kiarina/lib/redisearch/_core/operations/set.py +65 -0
  21. kiarina/lib/redisearch/_core/utils/__init__.py +0 -0
  22. kiarina/lib/redisearch/_core/utils/calc_score.py +35 -0
  23. kiarina/lib/redisearch/_core/utils/marshal_mappings.py +57 -0
  24. kiarina/lib/redisearch/_core/utils/parse_search_result.py +57 -0
  25. kiarina/lib/redisearch/_core/utils/unmarshal_mappings.py +57 -0
  26. kiarina/lib/redisearch/_core/views/__init__.py +0 -0
  27. kiarina/lib/redisearch/_core/views/document.py +25 -0
  28. kiarina/lib/redisearch/_core/views/info_result.py +24 -0
  29. kiarina/lib/redisearch/_core/views/search_result.py +31 -0
  30. kiarina/lib/redisearch/_sync/__init__.py +0 -0
  31. kiarina/lib/redisearch/_sync/client.py +179 -0
  32. kiarina/lib/redisearch/_sync/registry.py +16 -0
  33. kiarina/lib/redisearch/asyncio.py +33 -0
  34. kiarina/lib/redisearch/filter/__init__.py +61 -0
  35. kiarina/lib/redisearch/filter/_decorators.py +28 -0
  36. kiarina/lib/redisearch/filter/_enums.py +28 -0
  37. kiarina/lib/redisearch/filter/_field/__init__.py +5 -0
  38. kiarina/lib/redisearch/filter/_field/base.py +67 -0
  39. kiarina/lib/redisearch/filter/_field/numeric.py +178 -0
  40. kiarina/lib/redisearch/filter/_field/tag.py +142 -0
  41. kiarina/lib/redisearch/filter/_field/text.py +111 -0
  42. kiarina/lib/redisearch/filter/_model.py +93 -0
  43. kiarina/lib/redisearch/filter/_registry.py +153 -0
  44. kiarina/lib/redisearch/filter/_types.py +32 -0
  45. kiarina/lib/redisearch/filter/_utils.py +18 -0
  46. kiarina/lib/redisearch/py.typed +0 -0
  47. kiarina/lib/redisearch/schema/__init__.py +25 -0
  48. kiarina/lib/redisearch/schema/_field/__init__.py +0 -0
  49. kiarina/lib/redisearch/schema/_field/base.py +20 -0
  50. kiarina/lib/redisearch/schema/_field/numeric.py +33 -0
  51. kiarina/lib/redisearch/schema/_field/tag.py +46 -0
  52. kiarina/lib/redisearch/schema/_field/text.py +44 -0
  53. kiarina/lib/redisearch/schema/_field/vector/__init__.py +0 -0
  54. kiarina/lib/redisearch/schema/_field/vector/base.py +61 -0
  55. kiarina/lib/redisearch/schema/_field/vector/flat.py +40 -0
  56. kiarina/lib/redisearch/schema/_field/vector/hnsw.py +53 -0
  57. kiarina/lib/redisearch/schema/_model.py +98 -0
  58. kiarina/lib/redisearch/schema/_types.py +16 -0
  59. kiarina/lib/redisearch/settings.py +47 -0
  60. kiarina_lib_redisearch-1.0.0.dist-info/METADATA +886 -0
  61. kiarina_lib_redisearch-1.0.0.dist-info/RECORD +62 -0
  62. kiarina_lib_redisearch-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,61 @@
1
+ from typing import Awaitable, Literal, overload
2
+
3
+ from ..context import RedisearchContext
4
+ from ..utils.unmarshal_mappings import unmarshal_mappings
5
+ from ..views.document import Document
6
+ from .get_key import get_key
7
+
8
+
9
@overload
def get(
    mode: Literal["sync"],
    ctx: RedisearchContext,
    id: str,
) -> Document | None: ...


@overload
def get(
    mode: Literal["async"],
    ctx: RedisearchContext,
    id: str,
) -> Awaitable[Document | None]: ...


def get(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
    id: str,
) -> Document | None | Awaitable[Document | None]:
    """
    Get a document from the index.

    Returns None when no hash exists at the document's key.
    """
    key = get_key(ctx, id)

    def _to_document(raw: dict[bytes, bytes]) -> Document | None:
        # HGETALL returns an empty mapping when the key does not exist.
        if not raw:
            return None

        return Document(
            key=key,
            id=id,
            mapping=unmarshal_mappings(schema=ctx.schema, mapping=raw),
        )

    if mode == "sync":
        raw = ctx.redis.hgetall(key)
        assert isinstance(raw, dict)
        return _to_document(raw)

    async def _run() -> Document | None:
        pending = ctx.redis_async.hgetall(key)
        assert not isinstance(pending, dict)
        return _to_document(await pending)

    return _run()
@@ -0,0 +1,155 @@
1
+ from typing import Any, Awaitable, Literal, overload
2
+
3
+ from ...schema import RedisearchSchema
4
+ from ..context import RedisearchContext
5
+ from ..views.info_result import InfoResult
6
+
7
+
8
@overload
def get_info(
    mode: Literal["sync"],
    ctx: RedisearchContext,
) -> InfoResult: ...


@overload
def get_info(
    mode: Literal["async"],
    ctx: RedisearchContext,
) -> Awaitable[InfoResult]: ...


def get_info(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
) -> InfoResult | Awaitable[InfoResult]:
    """
    Get index information using FT.INFO command.
    """

    def _build(raw: dict[str, Any]) -> InfoResult:
        # Missing counters default to 0 / "" so a sparse FT.INFO reply
        # still produces a valid result object.
        return InfoResult(
            index_name=str(raw.get("index_name", "")),
            num_docs=int(raw.get("num_docs", 0)),
            num_terms=int(raw.get("num_terms", 0)),
            num_records=int(raw.get("num_records", 0)),
            index_schema=_parse_schema(ctx.schema, raw),
        )

    if mode == "sync":
        raw = ctx.redis.ft(index_name=ctx.settings.index_name).info()  # type: ignore[no-untyped-call]
        assert isinstance(raw, dict)
        return _build(raw)

    async def _run() -> InfoResult:
        raw = await ctx.redis_async.ft(index_name=ctx.settings.index_name).info()  # type: ignore[no-untyped-call]
        assert isinstance(raw, dict)
        return _build(raw)

    return _run()
53
+
54
+
55
def _parse_schema(schema: RedisearchSchema, result: dict[str, Any]) -> RedisearchSchema:
    """
    Parse the schema information from the FT.INFO results.

    Raises:
        ValueError: If the reply has no "attributes" entry.
    """
    if "attributes" not in result:
        raise ValueError("The FT.INFO results do not contain attributes.")

    fields: list[dict[str, Any]] = [
        _parse_field(_parse_attribute(attr)) for attr in result["attributes"]
    ]

    return RedisearchSchema.from_field_dicts(fields)
70
+
71
+
72
+ def _parse_attribute(attr: Any) -> dict[str, Any]:
73
+ attr_dict = {}
74
+
75
+ for i in range(0, len(attr), 2):
76
+ key = attr[i].decode("utf-8") if isinstance(attr[i], bytes) else attr[i]
77
+
78
+ if i + 1 >= len(attr):
79
+ break
80
+
81
+ value = attr[i + 1]
82
+
83
+ if isinstance(value, bytes):
84
+ value = value.decode("utf-8")
85
+
86
+ elif isinstance(value, list):
87
+ value = [v.decode("utf-8") if isinstance(v, bytes) else v for v in value]
88
+
89
+ attr_dict[key] = value
90
+
91
+ return attr_dict
92
+
93
+
94
def _parse_field(attr_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Build a schema field dict from a parsed FT.INFO attribute dict,
    dispatching on the reported field type.

    Raises:
        ValueError: If the field type is missing or not recognized.
    """
    field_type = _get_field_type(attr_dict)

    base: dict[str, Any] = {"name": str(attr_dict.get("identifier"))}

    parsers = {
        "tag": _parse_tag_field,
        "numeric": _parse_numeric_field,
        "text": _parse_text_field,
        "vector": _parse_vector_field,
    }

    if field_type not in parsers:
        raise ValueError(f"Unknown field type: {field_type}")

    return parsers[field_type](base, attr_dict)
110
+
111
+
112
+ def _parse_tag_field(
113
+ field_dict: dict[str, Any], attr_dict: dict[str, Any]
114
+ ) -> dict[str, Any]:
115
+ field_dict["separator"] = str(attr_dict.get("SEPARATOR", ","))
116
+ field_dict["case_sensitive"] = "CASE_SENSITIVE" in attr_dict
117
+ field_dict["no_index"] = "NO_INDEX" in attr_dict
118
+ field_dict["sortable"] = "SORTABLE" in attr_dict
119
+ return field_dict
120
+
121
+
122
+ def _parse_numeric_field(
123
+ field_dict: dict[str, Any], attr_dict: dict[str, Any]
124
+ ) -> dict[str, Any]:
125
+ field_dict["no_index"] = "NO_INDEX" in attr_dict
126
+ field_dict["sortable"] = "SORTABLE" in attr_dict
127
+ return field_dict
128
+
129
+
130
+ def _parse_text_field(
131
+ field_dict: dict[str, Any], attr_dict: dict[str, Any]
132
+ ) -> dict[str, Any]:
133
+ field_dict["weight"] = float(attr_dict.get("WEIGHT", 1.0))
134
+ field_dict["no_stem"] = "NO_STEM" in attr_dict
135
+ field_dict["withsuffixtrie"] = "WITHSUFFIX" in attr_dict
136
+ field_dict["no_index"] = "NO_INDEX" in attr_dict
137
+ field_dict["sortable"] = "SORTABLE" in attr_dict
138
+ return field_dict
139
+
140
+
141
+ def _parse_vector_field(
142
+ field_dict: dict[str, Any], attr_dict: dict[str, Any]
143
+ ) -> dict[str, Any]:
144
+ field_dict["dims"] = int(attr_dict.get("dim", 0))
145
+ field_dict["algorithm"] = str(attr_dict.get("algorithm", ""))
146
+ field_dict["datatype"] = str(attr_dict.get("data_type", ""))
147
+ field_dict["distance_metric"] = str(attr_dict.get("distance_metric", ""))
148
+ return field_dict
149
+
150
+
151
+ def _get_field_type(attr_dict: dict[str, Any]) -> str:
152
+ if "type" not in attr_dict:
153
+ raise ValueError("The FT.INFO results do not include the field type.")
154
+
155
+ return str(attr_dict["type"]).lower()
@@ -0,0 +1,8 @@
1
+ from ..context import RedisearchContext
2
+
3
+
4
def get_key(ctx: RedisearchContext, id: str) -> str:
    """
    Get the Redis key for a given Redisearch ID.

    The key is the configured key prefix followed directly by the ID.
    """
    prefix = ctx.settings.key_prefix
    return f"{prefix}{id}"
@@ -0,0 +1,160 @@
1
+ import logging
2
+ from typing import Any, Awaitable, Literal, overload
3
+
4
+ from ...schema import RedisearchSchema
5
+ from ..context import RedisearchContext
6
+ from .create_index import create_index
7
+ from .drop_index import drop_index
8
+ from .exists_index import exists_index
9
+ from .get_info import get_info
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
@overload
def migrate_index(
    mode: Literal["sync"],
    ctx: RedisearchContext,
) -> bool: ...


@overload
def migrate_index(
    mode: Literal["async"],
    ctx: RedisearchContext,
) -> Awaitable[bool]: ...


def migrate_index(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
) -> bool | Awaitable[bool]:
    """
    Migrate the search index to match the configured schema.

    Creates the index when it does not exist yet. Otherwise the live schema
    (read back via FT.INFO) is compared against the configured schema; when
    they differ, the index is dropped (keeping the documents) and re-created
    so Redis re-indexes the existing data.

    Returns:
        True when the index was created or re-created, False when no schema
        change was detected.
    """

    def _log_create_new_index() -> None:
        # Fixed log-message typo: "Createing" -> "Creating".
        logger.info("Creating new index '%s'", ctx.settings.index_name)

    def _log_no_schema_changes() -> None:
        logger.info("No schema changes detected, migration not needed.")

    def _log_migration_needed(diffs: dict[str, tuple[Any, Any]]) -> None:
        logger.info("Schema changes detected, migration needed:")

        for path, (old, new) in diffs.items():
            logger.info(" - %s: %r -> %r", path, old, new)

    def _log_delete_index() -> None:
        logger.info(
            "Deleting existing index '%s', data will be re-indexed",
            ctx.settings.index_name,
        )

    def _sync() -> bool:
        if not exists_index(mode="sync", ctx=ctx):
            _log_create_new_index()
            create_index(mode="sync", ctx=ctx)
            return True

        info_result = get_info(mode="sync", ctx=ctx)
        diffs = _check_schema_changes(current=info_result.index_schema, new=ctx.schema)

        if not diffs:
            _log_no_schema_changes()
            return False

        _log_migration_needed(diffs)

        # Drop only the index definition; documents stay and are re-indexed
        # by the new index.
        _log_delete_index()
        drop_index(mode="sync", ctx=ctx, delete_documents=False)

        _log_create_new_index()
        create_index(mode="sync", ctx=ctx)
        return True

    async def _async() -> bool:
        if not await exists_index(mode="async", ctx=ctx):
            _log_create_new_index()
            await create_index(mode="async", ctx=ctx)
            return True

        info_result = await get_info(mode="async", ctx=ctx)
        diffs = _check_schema_changes(current=info_result.index_schema, new=ctx.schema)

        if not diffs:
            _log_no_schema_changes()
            return False

        _log_migration_needed(diffs)

        _log_delete_index()
        await drop_index(mode="async", ctx=ctx, delete_documents=False)

        _log_create_new_index()
        await create_index(mode="async", ctx=ctx)
        return True

    if mode == "sync":
        return _sync()
    else:
        return _async()
102
+
103
+
104
def _check_schema_changes(
    current: RedisearchSchema,
    new: RedisearchSchema,
) -> dict[str, tuple[Any, Any]]:
    """
    Return a path -> (old, new) mapping of schema differences.

    Empty when the two schemas compare equal.
    """
    if current == new:
        return {}

    return _diff_dict(current.model_dump(), new.model_dump())
115
+
116
+
117
+ def _diff_dict(
118
+ d1: dict[str, Any], d2: dict[str, Any], prefix: str = ""
119
+ ) -> dict[str, tuple[Any, Any]]:
120
+ diffs: dict[str, tuple[Any, Any]] = {}
121
+
122
+ keys = set(d1.keys()) | set(d2.keys())
123
+
124
+ for k in keys:
125
+ v1, v2 = d1.get(k), d2.get(k)
126
+ path = f"{prefix}.{k}" if prefix else k
127
+
128
+ # Nested dict
129
+ if isinstance(v1, dict) and isinstance(v2, dict):
130
+ nested_diff = _diff_dict(v1, v2, prefix=path)
131
+ diffs.update(nested_diff)
132
+
133
+ # Nested list
134
+ elif isinstance(v1, list) and isinstance(v2, list):
135
+ max_len = max(len(v1), len(v2))
136
+
137
+ for i in range(max_len):
138
+ p = f"{path}[{i}]"
139
+
140
+ try:
141
+ item1, item2 = v1[i], v2[i]
142
+ except IndexError:
143
+ diffs[p] = (
144
+ v1[i] if i < len(v1) else None,
145
+ v2[i] if i < len(v2) else None,
146
+ )
147
+ continue
148
+
149
+ if isinstance(item1, dict) and isinstance(item2, dict):
150
+ nested_diff = _diff_dict(item1, item2, prefix=p)
151
+ diffs.update(nested_diff)
152
+
153
+ elif item1 != item2:
154
+ diffs[p] = (item1, item2)
155
+
156
+ # Different values
157
+ elif v1 != v2:
158
+ diffs[path] = (v1, v2)
159
+
160
+ return diffs
@@ -0,0 +1,60 @@
1
+ import logging
2
+ from typing import Awaitable, Literal, overload
3
+
4
+ from ..context import RedisearchContext
5
+ from .create_index import create_index
6
+ from .drop_index import drop_index
7
+ from .exists_index import exists_index
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
@overload
def reset_index(
    mode: Literal["sync"],
    ctx: RedisearchContext,
) -> None: ...


@overload
def reset_index(
    mode: Literal["async"],
    ctx: RedisearchContext,
) -> Awaitable[None]: ...


def reset_index(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
) -> None | Awaitable[None]:
    """
    Reset the search index.

    Drops the existing index together with its documents (when present),
    then creates a fresh, empty index.
    """

    def _note_drop() -> None:
        logger.info("Deleting existing index '%s'", ctx.settings.index_name)

    def _note_create() -> None:
        logger.info("Creating new index '%s'", ctx.settings.index_name)

    if mode == "sync":
        if exists_index(mode="sync", ctx=ctx):
            _note_drop()
            drop_index(mode="sync", ctx=ctx, delete_documents=True)

        _note_create()
        create_index(mode="sync", ctx=ctx)
        return None

    async def _run() -> None:
        if await exists_index(mode="async", ctx=ctx):
            _note_drop()
            await drop_index(mode="async", ctx=ctx, delete_documents=True)

        _note_create()
        await create_index(mode="async", ctx=ctx)

    return _run()
@@ -0,0 +1,111 @@
1
+ from typing import Any, Awaitable, Literal, overload
2
+
3
+ import numpy as np
4
+ from redis.commands.search.query import Query
5
+ from redis.commands.search.result import Result
6
+
7
+ from ...filter import (
8
+ RedisearchFilter,
9
+ RedisearchFilterConditions,
10
+ create_redisearch_filter,
11
+ )
12
+ from ..context import RedisearchContext
13
+ from ..utils.parse_search_result import parse_search_result
14
+ from ..views.search_result import SearchResult
15
+ from .count import count
16
+
17
+
18
@overload
def search(
    mode: Literal["sync"],
    ctx: RedisearchContext,
    vector: list[float],
    filter: RedisearchFilter | RedisearchFilterConditions | None = None,
    offset: int | None = None,
    limit: int | None = None,
    return_fields: list[str] | None = None,
) -> SearchResult: ...


@overload
def search(
    mode: Literal["async"],
    ctx: RedisearchContext,
    vector: list[float],
    filter: RedisearchFilter | RedisearchFilterConditions | None = None,
    offset: int | None = None,
    limit: int | None = None,
    return_fields: list[str] | None = None,
) -> Awaitable[SearchResult]: ...


def search(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
    vector: list[float],
    filter: RedisearchFilter | RedisearchFilterConditions | None = None,
    offset: int | None = None,
    limit: int | None = None,
    return_fields: list[str] | None = None,
) -> SearchResult | Awaitable[SearchResult]:
    """
    Search documents using vector similarity search (KNN).

    Args:
        vector: Query embedding; encoded with the schema's vector dtype.
        filter: Optional pre-filter (filter object or raw conditions).
        offset: Paging offset (defaults to 0).
        limit: Maximum number of hits; when omitted, all matching documents
            are returned (a count query determines the total).
        return_fields: Fields to return per hit; "distance" is always added.
    """
    # filter_query
    if filter is not None:
        filter = create_redisearch_filter(filter=filter, schema=ctx.schema)

    filter_query = "*" if filter is None else str(filter)

    # vector_field_name
    vector_field_name = ctx.schema.vector_field.name

    # return_fields
    # Bug fix: copy before mutating — previously "distance" was appended to
    # the caller-supplied list, mutating the caller's argument.
    return_fields = list(return_fields) if return_fields else []

    if "distance" not in return_fields:
        return_fields.append("distance")

    # params
    params: dict[str, str | int | float | bytes] = {
        "vector": np.array(vector).astype(ctx.schema.vector_field.dtype).tobytes()
    }

    def _build_query(limit: int) -> Query:
        query = Query(
            f"({filter_query})=>[KNN {limit} @{vector_field_name} $vector AS distance]"
        )

        if return_fields:
            query = query.return_fields(*return_fields)
        else:
            # Defensive: unreachable in practice since "distance" is always
            # appended above.
            query = query.no_content()

        query = query.sort_by("distance")
        query = query.paging(offset or 0, limit)

        return query

    def _parse(result: Any) -> SearchResult:
        assert isinstance(result, Result)
        return parse_search_result(
            key_prefix=ctx.settings.key_prefix,
            schema=ctx.schema,
            return_fields=return_fields,
            result=result,
        )

    def _sync() -> SearchResult:
        # No explicit limit: fetch every matching document.
        query = _build_query(limit or count("sync", ctx, filter).total)
        result = ctx.redis.ft(ctx.settings.index_name).search(query, params)
        return _parse(result)

    async def _async() -> SearchResult:
        query = _build_query(limit or (await count("async", ctx, filter)).total)
        result = await ctx.redis_async.ft(ctx.settings.index_name).search(query, params)  # type: ignore
        return _parse(result)

    if mode == "sync":
        return _sync()
    else:
        return _async()
@@ -0,0 +1,65 @@
1
+ from typing import Any, Awaitable, Literal, overload
2
+
3
+ from ..context import RedisearchContext
4
+ from ..utils.marshal_mappings import marshal_mappings
5
+ from .get_key import get_key
6
+
7
+
8
@overload
def set(
    mode: Literal["sync"],
    ctx: RedisearchContext,
    mapping: dict[str, Any],
    *,
    id: str | None = None,
) -> None: ...


@overload
def set(
    mode: Literal["async"],
    ctx: RedisearchContext,
    mapping: dict[str, Any],
    *,
    id: str | None = None,
) -> Awaitable[None]: ...


def set(
    mode: Literal["sync", "async"],
    ctx: RedisearchContext,
    mapping: dict[str, Any],
    *,
    id: str | None = None,
) -> None | Awaitable[None]:
    """
    Set a document in the index.

    Fields not present in the schema are saved as they are.
    Fields present in the schema are converted to the appropriate type and stored.
    """
    if id is None:
        # Fall back to the "id" field embedded in the mapping.
        if "id" not in mapping:
            raise ValueError(
                'Either "id" parameter or "id" field in mapping must be provided.'
            )

        id = str(mapping.get("id"))

    key = get_key(ctx, id)
    marshaled = marshal_mappings(schema=ctx.schema, mapping=mapping)

    if mode == "sync":
        ctx.redis.hset(key, mapping=marshaled)
        return None

    async def _run() -> None:
        pending = ctx.redis_async.hset(key, mapping=marshaled)
        assert not isinstance(pending, int)
        await pending

    return _run()
File without changes
@@ -0,0 +1,35 @@
1
+ import math
2
+ from typing import Literal
3
+
4
+
5
def calc_score(
    distance: float,
    *,
    datatype: Literal["FLOAT32", "FLOAT64"],
    distance_metric: Literal["COSINE", "IP", "L2"],
) -> float:
    """
    Calculate relevance score from distance.

    The distance is first rounded to the precision the storage datatype can
    reliably carry, then mapped to a similarity score.

    Raises:
        ValueError: If the distance metric is not supported.
    """
    precision = 4 if datatype == "FLOAT32" else 7
    distance = round(distance, precision)

    if distance_metric == "COSINE":
        # Normalise the cosine distance to a score within the range [0, 1]
        return 1.0 - distance

    if distance_metric == "IP":
        # Normalise the inner product distance to a score within the range [0, 1]
        return 1.0 - distance if distance > 0 else -1.0 * distance

    if distance_metric == "L2":
        # Convert the Euclidean distance to a similarity score within the range [0, 1]
        return 1.0 - distance / math.sqrt(2)

    raise ValueError(f"Unsupported distance metric: {distance_metric}")