nvidia-nat-redis 1.3.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nat/meta/pypi.md ADDED
@@ -0,0 +1,23 @@
+ <!--
+ SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ SPDX-License-Identifier: Apache-2.0
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+ ![NVIDIA NeMo Agent Toolkit](https://media.githubusercontent.com/media/NVIDIA/NeMo-Agent-Toolkit/refs/heads/main/docs/source/_static/banner.png "NeMo Agent toolkit banner image")
+
+ # NVIDIA NeMo Agent Toolkit Subpackage
+ This is a subpackage for Redis memory integration in NeMo Agent toolkit.
+
+ For more information about NeMo Agent toolkit, please visit the [NeMo Agent toolkit package](https://pypi.org/project/nvidia-nat/).
@@ -0,0 +1,58 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pydantic import Field
+
+ from nat.builder.builder import Builder
+ from nat.cli.register_workflow import register_memory
+ from nat.data_models.component_ref import EmbedderRef
+ from nat.data_models.memory import MemoryBaseConfig
+
+
+ class RedisMemoryClientConfig(MemoryBaseConfig, name="redis_memory"):
+     host: str = Field(default="localhost", description="Redis server host")
+     db: int = Field(default=0, description="Redis DB")
+     port: int = Field(default=6379, description="Redis server port")
+     key_prefix: str = Field(default="nat", description="Key prefix to use for redis keys")
+     embedder: EmbedderRef = Field(description=("Instance name of the embedder to use from the workflow "
+                                                "configuration object."))
+
+
+ @register_memory(config_type=RedisMemoryClientConfig)
+ async def redis_memory_client(config: RedisMemoryClientConfig, builder: Builder):
+
+     import redis.asyncio as redis
+
+     from nat.builder.framework_enum import LLMFrameworkEnum
+     from nat.plugins.redis.redis_editor import RedisEditor
+
+     from .schema import ensure_index_exists
+
+     redis_client = redis.Redis(host=config.host,
+                                port=config.port,
+                                db=config.db,
+                                decode_responses=True,
+                                socket_timeout=5.0,
+                                socket_connect_timeout=5.0)
+
+     embedder = await builder.get_embedder(config.embedder, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+
+     test_embedding = await embedder.aembed_query("test")
+     embedding_dim = len(test_embedding)
+     await ensure_index_exists(client=redis_client, key_prefix=config.key_prefix, embedding_dim=embedding_dim)
+
+     memory_editor = RedisEditor(redis_client=redis_client, key_prefix=config.key_prefix, embedder=embedder)
+
+     yield memory_editor
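For orientation, here is a minimal sketch (not part of the package) of constructing the `redis_memory` configuration directly in Python. The module path follows the `nat.plugins.redis` package shown in the imports above; `"my_embedder"` is an illustrative embedder instance name, and we assume `EmbedderRef` validates from a plain string naming an embedder defined elsewhere in the workflow configuration.

import asyncio

from nat.plugins.redis.memory import RedisMemoryClientConfig

# All values mirror the declared defaults except the embedder reference,
# which must name an embedder instance in the workflow configuration.
config = RedisMemoryClientConfig(
    host="localhost",
    port=6379,
    db=0,
    key_prefix="nat",        # memory keys become "nat:memory:<id>"
    embedder="my_embedder",  # illustrative instance name
)
print(config.model_dump())
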
@@ -0,0 +1,40 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pydantic import Field
+
+ from nat.builder.builder import Builder
+ from nat.cli.register_workflow import register_object_store
+ from nat.data_models.object_store import ObjectStoreBaseConfig
+
+
+ class RedisObjectStoreClientConfig(ObjectStoreBaseConfig, name="redis"):
+     """
+     Object store that stores objects in a Redis database.
+     """
+
+     host: str = Field(default="localhost", description="The host of the Redis server")
+     db: int = Field(default=0, description="The Redis logical database number")
+     port: int = Field(default=6379, description="The port of the Redis server")
+     bucket_name: str = Field(description="The name of the bucket to use for the object store")
+
+
+ @register_object_store(config_type=RedisObjectStoreClientConfig)
+ async def redis_object_store_client(config: RedisObjectStoreClientConfig, _builder: Builder):
+
+     from .redis_object_store import RedisObjectStore
+
+     async with RedisObjectStore(**config.model_dump(exclude={"type"})) as store:
+         yield store
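The registration above forwards the config fields straight into the `RedisObjectStore` constructor. A rough sketch of that mapping, with an illustrative bucket name and assuming `model_dump(exclude={"type"})` leaves only the connection fields:

from nat.plugins.redis.object_store import RedisObjectStoreClientConfig

cfg = RedisObjectStoreClientConfig(bucket_name="my-bucket")  # host/port/db keep their defaults
kwargs = cfg.model_dump(exclude={"type"})
# roughly {"host": "localhost", "db": 0, "port": 6379, "bucket_name": "my-bucket"},
# which matches the keyword-only RedisObjectStore(*, bucket_name, host, port, db) defined below.
print(kwargs)
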
@@ -0,0 +1,228 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import secrets
+
+ import numpy as np
+ import redis.asyncio as redis
+ import redis.exceptions as redis_exceptions
+ from langchain_core.embeddings import Embeddings
+ from redis.commands.search.query import Query
+
+ from nat.memory.interfaces import MemoryEditor
+ from nat.memory.models import MemoryItem
+
+ logger = logging.getLogger(__name__)
+
+ INDEX_NAME = "memory_idx"
+
+
+ class RedisEditor(MemoryEditor):
+     """
+     Wrapper class that implements NAT interfaces for Redis memory storage.
+     """
+
+     def __init__(self, redis_client: redis.Redis, key_prefix: str, embedder: Embeddings):
+         """
+         Initialize Redis client for memory storage.
+
+         Args:
+             redis_client: (redis.Redis) Redis client
+             key_prefix: (str) Redis key prefix
+             embedder: (Embeddings) Embedder for semantic search functionality
+         """
+
+         self._client: redis.Redis = redis_client
+         self._key_prefix: str = key_prefix
+         self._embedder: Embeddings = embedder
+
+     async def add_items(self, items: list[MemoryItem]) -> None:
+         """
+         Insert multiple MemoryItems into Redis.
+         Each MemoryItem is stored with its metadata and tags.
+         """
+         logger.debug("Attempting to add %d items to Redis", len(items))
+
+         for memory_item in items:
+             item_meta = memory_item.metadata
+             conversation = memory_item.conversation
+             user_id = memory_item.user_id
+             tags = memory_item.tags
+             memory_id = secrets.token_hex(4)  # e.g. 02ba3fe9
+
+             # Create a unique key for this memory item
+             memory_key = f"{self._key_prefix}:memory:{memory_id}"
+             logger.debug("Generated memory key: %s", memory_key)
+
+             # Prepare memory data
+             memory_data = {
+                 "conversation": conversation,
+                 "user_id": user_id,
+                 "tags": tags,
+                 "metadata": item_meta,
+                 "memory": memory_item.memory or ""
+             }
+             logger.debug("Prepared memory data for key %s", memory_key)
+
+             # If we have memory, compute and store the embedding
+             if memory_item.memory:
+                 logger.debug("Computing embedding for memory text")
+                 search_vector = await self._embedder.aembed_query(memory_item.memory)
+                 logger.debug("Generated embedding vector of length: %d", len(search_vector))
+                 memory_data["embedding"] = search_vector
+
+             try:
+                 # Store as JSON in Redis
+                 logger.debug("Attempting to store memory data in Redis for key: %s", memory_key)
+                 await self._client.json().set(memory_key, "$", memory_data)
+                 logger.debug("Successfully stored memory data for key: %s", memory_key)
+
+                 # Verify the data was stored
+                 stored_data = await self._client.json().get(memory_key)
+                 logger.debug("Verified data storage for key %s: %s", memory_key, bool(stored_data))
+
+             except redis_exceptions.ResponseError as e:
+                 logger.error("Failed to store memory item: %s", e)
+                 raise
+             except redis_exceptions.ConnectionError as e:
+                 logger.error("Redis connection error while storing memory item: %s", e)
+                 raise
+
+     async def search(self, query: str, top_k: int = 5, **kwargs) -> list[MemoryItem]:
+         """
+         Retrieve items relevant to the given query.
+
+         Args:
+             query (str): The query string to match.
+             top_k (int): Maximum number of items to return.
+             kwargs (dict): Keyword arguments to pass to the search method.
+
+         Returns:
+             list[MemoryItem]: The most relevant MemoryItems for the given query.
+         """
+         logger.debug("Search called with query: %s, top_k: %d, kwargs: %s", query, top_k, kwargs)
+
+         user_id = kwargs.get("user_id", "redis")  # TODO: remove this fallback username
+         logger.debug("Using user_id: %s", user_id)
+
+         # Perform vector search using Redis search
+         logger.debug("Using embedder for vector search")
+         try:
+             logger.debug("Generating embedding for query: '%s'", query)
+             query_vector = await self._embedder.aembed_query(query)
+             logger.debug("Generated embedding vector of length: %d", len(query_vector))
+         except Exception as e:
+             logger.error("Failed to generate embedding: %s", e)
+             raise
+
+         # Create vector search query
+         search_query = (
+             Query(f"(@user_id:{user_id})=>[KNN {top_k} @embedding $vec AS score]").sort_by("score").return_fields(
+                 "conversation", "user_id", "tags", "metadata", "memory", "score").dialect(2))
+         logger.debug("Created search query: %s", search_query)
+         logger.debug("Query string: %s", search_query.query_string())
+
+         # Convert query vector to bytes
+         try:
+             logger.debug("Converting query vector to bytes")
+             query_vector_bytes = np.array(query_vector, dtype=np.float32).tobytes()
+             logger.debug("Converted vector to bytes of length: %d", len(query_vector_bytes))
+         except Exception as e:
+             logger.error("Failed to convert vector to bytes: %s", e)
+             raise
+
+         try:
+             # Execute search with vector parameters
+             logger.debug("Executing Redis search with vector parameters")
+             logger.debug("Search query parameters: vec length=%d", len(query_vector_bytes))
+
+             # Log the actual query being executed
+             logger.debug("Full search query: %s", search_query.query_string())
+
+             # Check if there are any documents in the index
+             try:
+                 total_docs = await self._client.ft(INDEX_NAME).info()
+                 logger.debug("Total documents in index: %s", total_docs.get('num_docs', 0))
+             except Exception as e:
+                 logger.exception("Failed to get index info: %s", e)
+
+             # Execute the search
+             results = await self._client.ft(INDEX_NAME).search(search_query, query_params={"vec": query_vector_bytes})
+
+             # Log detailed results information
+             logger.debug("Search returned %d results", len(results.docs))
+             logger.debug("Total results found: %d", results.total)
+
+             # Convert results to MemoryItems
+             memories = []
+             for i, doc in enumerate(results.docs):
+                 try:
+                     logger.debug("Processing result %d/%d", i + 1, len(results.docs))
+                     logger.debug("Similarity score: %s", getattr(doc, 'score', 0))
+
+                     # Get the full document data
+                     full_doc = await self._client.json().get(doc.id)
+                     logger.debug("Extracted data for result %d: %s", i + 1, full_doc)
+                     memory_item = self._create_memory_item(dict(full_doc), user_id)
+                     memories.append(memory_item)
+                     logger.debug("Successfully created MemoryItem for result %d", i + 1)
+                 except Exception as e:
+                     logger.error("Failed to process result %d: %s", i + 1, e)
+                     raise
+
+             logger.debug("Successfully processed all %d results", len(memories))
+             return memories
+         except redis_exceptions.ResponseError as e:
+             logger.error("Search failed with ResponseError: %s", e)
+             raise
+         except redis_exceptions.ConnectionError as e:
+             logger.error("Search failed with ConnectionError: %s", e)
+             raise
+         except Exception as e:
+             logger.error("Unexpected error during search: %s", e)
+             raise
+
+     def _create_memory_item(self, memory_data: dict, user_id: str) -> MemoryItem:
+         """Helper method to create a MemoryItem from Redis data."""
+         # Ensure tags is always a list
+         tags = memory_data.get("tags", [])
+         # Tags may come back from Redis as a single string; normalize to a list
+         if isinstance(tags, str):
+             tags = [tags]
+         elif not isinstance(tags, list):
+             tags = []
+
+         return MemoryItem(conversation=memory_data.get("conversation", []),
+                           user_id=user_id,
+                           memory=memory_data.get("memory", ""),
+                           tags=tags,
+                           metadata=memory_data.get("metadata", {}))
+
+     async def remove_items(self, **kwargs):
+         """
+         Remove all memory items stored under the configured key prefix.
+         """
+         try:
+             pattern = f"{self._key_prefix}:memory:*"
+             keys = await self._client.keys(pattern)
+             if keys:
+                 await self._client.delete(*keys)
+         except redis_exceptions.ResponseError as e:
+             logger.error("Failed to remove items: %s", e)
+             raise
+         except redis_exceptions.ConnectionError as e:
+             logger.error("Redis connection error while removing items: %s", e)
+             raise
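A minimal end-to-end sketch of using the editor above, assuming a local Redis Stack instance (RediSearch and RedisJSON modules) and LangChain's FakeEmbeddings as a stand-in embedder. The key prefix, user id, memory text, and the 384-dimension choice are illustrative, not part of the package.

import asyncio

import redis.asyncio as redis
from langchain_core.embeddings import FakeEmbeddings

from nat.memory.models import MemoryItem
from nat.plugins.redis.redis_editor import RedisEditor
from nat.plugins.redis.schema import ensure_index_exists


async def main():
    client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
    embedder = FakeEmbeddings(size=384)  # stand-in for a real embedder

    # Create the search index before writing items, as the memory registration does.
    await ensure_index_exists(client=client, key_prefix="nat", embedding_dim=384)

    editor = RedisEditor(redis_client=client, key_prefix="nat", embedder=embedder)
    await editor.add_items([
        MemoryItem(conversation=[],
                   user_id="alice",
                   memory="Prefers window seats",
                   tags=["travel"],
                   metadata={})
    ])
    print(await editor.search("seating preference", top_k=3, user_id="alice"))


asyncio.run(main())
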
@@ -0,0 +1,126 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ import redis.asyncio as redis
+
+ from nat.data_models.object_store import KeyAlreadyExistsError
+ from nat.data_models.object_store import NoSuchKeyError
+ from nat.object_store.interfaces import ObjectStore
+ from nat.object_store.models import ObjectStoreItem
+ from nat.utils.type_utils import override
+
+ logger = logging.getLogger(__name__)
+
+
+ class RedisObjectStore(ObjectStore):
+     """
+     Implementation of ObjectStore that stores objects in Redis.
+
+     Each object is stored as a single JSON-serialized value at key "nat/object_store/{bucket_name}/{object_key}".
+     """
+
+     def __init__(self, *, bucket_name: str, host: str, port: int, db: int):
+
+         super().__init__()
+
+         self._bucket_name = bucket_name
+         self._host = host
+         self._port = port
+         self._db = db
+         self._client: redis.Redis | None = None
+
+     async def __aenter__(self) -> "RedisObjectStore":
+
+         if self._client is not None:
+             raise RuntimeError("Connection already established")
+
+         self._client = redis.Redis(
+             host=self._host,
+             port=self._port,
+             db=self._db,
+             socket_timeout=5.0,
+             socket_connect_timeout=5.0,
+         )
+
+         # Ping to ensure connectivity
+         res = await self._client.ping()
+         if not res:
+             raise RuntimeError("Failed to connect to Redis")
+
+         logger.info("Connected Redis client for %s at %s:%s/%s", self._bucket_name, self._host, self._port, self._db)
+
+         return self
+
+     async def __aexit__(self, exc_type, exc_value, traceback) -> None:
+
+         if not self._client:
+             raise RuntimeError("Connection not established")
+
+         await self._client.close()
+         self._client = None
+
+     def _make_key(self, key: str) -> str:
+         return f"nat/object_store/{self._bucket_name}/{key}"
+
+     @override
+     async def put_object(self, key: str, item: ObjectStoreItem):
+
+         if not self._client:
+             raise RuntimeError("Connection not established")
+
+         full_key = self._make_key(key)
+
+         item_json = item.model_dump_json()
+         # Redis SET with NX ensures we do not overwrite existing keys
+         if not await self._client.set(full_key, item_json, nx=True):
+             raise KeyAlreadyExistsError(key=key,
+                                         additional_message=f"Redis bucket {self._bucket_name} already has key {key}")
+
+     @override
+     async def upsert_object(self, key: str, item: ObjectStoreItem):
+
+         if not self._client:
+             raise RuntimeError("Connection not established")
+
+         full_key = self._make_key(key)
+         item_json = item.model_dump_json()
+         await self._client.set(full_key, item_json)
+
+     @override
+     async def get_object(self, key: str) -> ObjectStoreItem:
+
+         if not self._client:
+             raise RuntimeError("Connection not established")
+
+         full_key = self._make_key(key)
+         data = await self._client.get(full_key)
+         if data is None:
+             raise NoSuchKeyError(key=key,
+                                  additional_message=f"Redis bucket {self._bucket_name} does not have key {key}")
+         return ObjectStoreItem.model_validate_json(data)
+
+     @override
+     async def delete_object(self, key: str):
+
+         if not self._client:
+             raise RuntimeError("Connection not established")
+
+         full_key = self._make_key(key)
+         deleted = await self._client.delete(full_key)
+         if deleted == 0:
+             raise NoSuchKeyError(key=key,
+                                  additional_message=f"Redis bucket {self._bucket_name} does not have key {key}")
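A usage sketch for the store above, assuming a local Redis server. The bucket name, key, and payload are illustrative, and the `ObjectStoreItem` field names (data/content_type/metadata) are an assumption that should be checked against nat.object_store.models.

import asyncio

from nat.object_store.models import ObjectStoreItem
from nat.plugins.redis.redis_object_store import RedisObjectStore


async def main():
    async with RedisObjectStore(bucket_name="reports", host="localhost", port=6379, db=0) as store:
        # Assumed ObjectStoreItem fields; adjust to the actual model.
        item = ObjectStoreItem(data=b"hello", content_type="text/plain", metadata={"source": "example"})
        await store.put_object("greeting.txt", item)       # raises KeyAlreadyExistsError if the key exists
        fetched = await store.get_object("greeting.txt")   # raises NoSuchKeyError if the key is missing
        print(fetched.data)
        await store.delete_object("greeting.txt")


asyncio.run(main())
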
@@ -0,0 +1,22 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort:skip_file
+
+ # Import any providers which need to be automatically registered here
+
+ from . import memory
+ from . import object_store
@@ -0,0 +1,136 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+
+ import redis.asyncio as redis
+ import redis.exceptions as redis_exceptions
+ from redis.commands.search.field import TagField
+ from redis.commands.search.field import TextField
+ from redis.commands.search.field import VectorField
+ from redis.commands.search.indexDefinition import IndexDefinition
+ from redis.commands.search.indexDefinition import IndexType
+
+ logger = logging.getLogger(__name__)
+
+ INDEX_NAME = "memory_idx"
+ DEFAULT_DIM = 384  # Default embedding dimension
+
+
+ def create_schema(embedding_dim: int = DEFAULT_DIM):
+     """
+     Create the Redis search schema for redis_memory.
+
+     Args:
+         embedding_dim (int): Dimension of the embedding vectors
+
+     Returns:
+         tuple: Schema definition for Redis search
+     """
+     logger.info("Creating schema with embedding dimension: %d", embedding_dim)
+
+     embedding_field = VectorField("$.embedding",
+                                   "HNSW",
+                                   {
+                                       "TYPE": "FLOAT32",
+                                       "DIM": embedding_dim,
+                                       "DISTANCE_METRIC": "L2",
+                                       "INITIAL_CAP": 100,
+                                       "M": 16,
+                                       "EF_CONSTRUCTION": 200,
+                                       "EF_RUNTIME": 10
+                                   },
+                                   as_name="embedding")
+     logger.info("Created embedding field with dimension %d", embedding_dim)
+
+     schema = (
+         # Redis search can't directly index complex objects (e.g. conversation and metadata) in return_fields
+         # They need to be retrieved via json().get() for full object access
+         TextField("$.user_id", as_name="user_id"),
+         TagField("$.tags[*]", as_name="tags"),
+         TextField("$.memory", as_name="memory"),
+         embedding_field)
+
+     # Log the schema details
+     logger.info("Schema fields:")
+     for field in schema:
+         logger.info(" - %s: %s", field.name, type(field).__name__)
+
+     return schema
+
+
+ async def ensure_index_exists(client: redis.Redis, key_prefix: str, embedding_dim: int | None) -> None:
+     """
+     Ensure the Redis search index exists, creating it if necessary.
+
+     Args:
+         client (redis.Redis): Redis client instance
+         key_prefix (str): Prefix for keys to be indexed
+         embedding_dim (int | None): Dimension of embedding vectors. If None, uses the default.
+     """
+     try:
+         # Check if index exists
+         logger.info("Checking if index '%s' exists...", INDEX_NAME)
+         info = await client.ft(INDEX_NAME).info()
+         logger.info("Redis search index '%s' exists.", INDEX_NAME)
+
+         # Log the existing schema
+         logger.debug("Existing index schema: %s", info.get('attributes', []))
+
+         return
+     except redis_exceptions.ResponseError as ex:
+         error_msg = str(ex)
+         if "no such index" not in error_msg.lower() and "Index needs recreation" not in error_msg:
+             logger.error("Unexpected Redis error: %s", error_msg)
+             raise
+
+     # Index doesn't exist or needs recreation
+     logger.info("Creating Redis search index '%s' with prefix '%s'", INDEX_NAME, key_prefix)
+
+     # Drop any existing index
+     try:
+         logger.info("Attempting to drop existing index '%s' if it exists", INDEX_NAME)
+         await client.ft(INDEX_NAME).dropindex()
+         logger.info("Successfully dropped existing index '%s'", INDEX_NAME)
+     except redis_exceptions.ResponseError as e:
+         if "no such index" not in str(e).lower():
+             logger.warning("Error while dropping index: %s", str(e))
+
+     # Create new schema and index
+     schema = create_schema(embedding_dim or DEFAULT_DIM)
+     logger.info("Created schema with embedding dimension: %d", embedding_dim or DEFAULT_DIM)
+
+     try:
+         # Create the index
+         logger.info("Creating new index '%s' with schema", INDEX_NAME)
+         await client.ft(INDEX_NAME).create_index(schema,
+                                                  definition=IndexDefinition(prefix=[f"{key_prefix}:"],
+                                                                             index_type=IndexType.JSON))
+
+         # Verify index was created
+         info = await client.ft(INDEX_NAME).info()
+         logger.info("Successfully created Redis search index '%s'", INDEX_NAME)
+         logger.debug("Redis search index info: %s", info)
+
+         # Verify the schema
+         schema = info.get('attributes', [])
+         logger.debug("New index schema: %s", schema)
+
+     except redis_exceptions.ResponseError as e:
+         logger.error("Failed to create index: %s", str(e))
+         raise
+     except redis_exceptions.ConnectionError as e:
+         logger.error("Redis connection error while creating index: %s", str(e))
+         raise
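Because the schema indexes JSON paths with TEXT/TAG/VECTOR fields and uses IndexType.JSON, the target Redis deployment needs the RediSearch and RedisJSON modules (e.g. Redis Stack). A small standalone sketch of creating and inspecting the index against a local instance; the key prefix and the 768-dimension value are illustrative and must match the embedder used at query time:

import asyncio

import redis.asyncio as redis

from nat.plugins.redis.schema import ensure_index_exists


async def main():
    # Assumes Redis Stack (RediSearch + RedisJSON) listening on localhost:6379.
    client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
    await ensure_index_exists(client=client, key_prefix="nat", embedding_dim=768)
    # Inspect the resulting index definition and document count.
    print(await client.ft("memory_idx").info())


asyncio.run(main())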