sharedkernel 1.4.4__tar.gz → 1.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/PKG-INFO +46 -43
  2. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/README.md +2 -0
  3. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/setup.cfg +4 -4
  4. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/setup.py +40 -39
  5. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/common.py +29 -29
  6. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/database/__init__.py +1 -1
  7. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/database/mongo_generic_repository.py +51 -50
  8. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/database/vector_database_repository/__init__.py +2 -2
  9. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/database/vector_database_repository/chroma_startegy.py +39 -39
  10. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/database/vector_database_repository/milvus_strategy.py +50 -50
  11. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/database/vector_database_repository/vector_database_repository.py +27 -27
  12. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/database/vector_database_repository/vector_database_strategy.py +22 -22
  13. sharedkernel-1.5.0/sharedkernel/date_converter.py +122 -0
  14. sharedkernel-1.5.0/sharedkernel/enum/__init__.py +2 -0
  15. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/enum/error_code.py +14 -14
  16. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/enum/vector_database_type.py +5 -5
  17. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/exception/__init__.py +4 -4
  18. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/exception/exception.py +39 -39
  19. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/exception/exception_handlers.py +75 -75
  20. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/jwt_service.py +38 -38
  21. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/objects/__init__.py +2 -2
  22. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/objects/base_document.py +8 -8
  23. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/objects/jwt_model.py +6 -6
  24. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/objects/result.py +28 -28
  25. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/regex_masking.py +92 -92
  26. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel/string_extentions.py +4 -4
  27. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel.egg-info/PKG-INFO +46 -43
  28. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel.egg-info/requires.txt +2 -1
  29. sharedkernel-1.4.4/sharedkernel/date_converter.py +0 -32
  30. sharedkernel-1.4.4/sharedkernel/enum/__init__.py +0 -2
  31. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel.egg-info/SOURCES.txt +0 -0
  32. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel.egg-info/dependency_links.txt +0 -0
  33. {sharedkernel-1.4.4 → sharedkernel-1.5.0}/sharedkernel.egg-info/top_level.txt +0 -0
@@ -1,43 +1,46 @@
1
- Metadata-Version: 2.1
2
- Name: sharedkernel
3
- Version: 1.4.4
4
- Summary: sharekernel is an shared package between all python projects
5
- Author: Smilinno
6
- Description-Content-Type: text/markdown
7
- Requires-Dist: numpy
8
- Requires-Dist: requests
9
- Requires-Dist: pymongo
10
- Requires-Dist: fastapi==0.89.1
11
- Requires-Dist: PyJWT
12
- Requires-Dist: pymilvus
13
- Requires-Dist: chromadb
14
- Requires-Dist: persian_tools
15
- Requires-Dist: sentry-sdk
16
- Requires-Dist: jdatetime
17
-
18
- # SharedKernel
19
- this a shared kernel package
20
-
21
- # Change Log
22
- ### Version 1.4.4
23
- - Fix regex masking bugs
24
- ### Version 1.4.3
25
- - Fix collection bug in MongoGenericRepository
26
- ### Version 1.4.2
27
- - Fix minor bugs
28
- ### Version 1.4.1
29
- - Fix minor bug in MongoGenericRepository
30
- ### Version 1.4.0
31
- - Implement date convertor for jalali and georgian
32
- ### Version 1.3.0
33
- - Implement Sentry For Log Exceptions
34
- ### Version 1.2.0
35
- - Implement Regex Masking
36
- # Create Package
37
- py -m pip install --upgrade build
38
- py -m build
39
- py -m pip install --upgrade twine
40
- py -m twine upload dist/*
41
-
42
- # Pypi
43
- pip install sharedkernel
1
+ Metadata-Version: 2.1
2
+ Name: sharedkernel
3
+ Version: 1.5.0
4
+ Summary: sharekernel is an shared package between all python projects
5
+ Author: Smilinno
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: numpy
8
+ Requires-Dist: requests
9
+ Requires-Dist: pymongo
10
+ Requires-Dist: fastapi==0.111.0
11
+ Requires-Dist: PyJWT
12
+ Requires-Dist: pymilvus
13
+ Requires-Dist: chromadb
14
+ Requires-Dist: persian_tools
15
+ Requires-Dist: sentry-sdk
16
+ Requires-Dist: jdatetime
17
+ Requires-Dist: persiantools
18
+
19
+ # SharedKernel
20
+ this is a shared kernel package
21
+
22
+ # Change Log
23
+ ### Version 1.4.5
24
+ - upgrade fastapi version
25
+ ### Version 1.4.4
26
+ - Fix regex masking bugs
27
+ ### Version 1.4.3
28
+ - Fix collection bug in MongoGenericRepository
29
+ ### Version 1.4.2
30
+ - Fix minor bugs
31
+ ### Version 1.4.1
32
+ - Fix minor bug in MongoGenericRepository
33
+ ### Version 1.4.0
34
+ - Implement date converter for Jalali and Gregorian calendars
35
+ ### Version 1.3.0
36
+ - Implement Sentry For Log Exceptions
37
+ ### Version 1.2.0
38
+ - Implement Regex Masking
39
+ # Create Package
40
+ py -m pip install --upgrade build
41
+ py -m build
42
+ py -m pip install --upgrade twine
43
+ py -m twine upload dist/*
44
+
45
+ # Pypi
46
+ pip install sharedkernel
@@ -2,6 +2,8 @@
2
2
  this is a shared kernel package
3
3
 
4
4
  # Change Log
5
+ ### Version 1.4.5
6
+ - upgrade fastapi version
5
7
  ### Version 1.4.4
6
8
  - Fix regex masking bugs
7
9
  ### Version 1.4.3
@@ -1,4 +1,4 @@
1
- [egg_info]
2
- tag_build =
3
- tag_date = 0
4
-
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -1,39 +1,40 @@
1
- from setuptools import setup
2
-
3
- # read the contents of your README file
4
- from pathlib import Path
5
-
6
- this_directory = Path(__file__).parent
7
- long_description = (this_directory / "README.md").read_text()
8
-
9
- setup(
10
- # Needed to silence warnings (and to be a worthwhile package)
11
- name="sharedkernel",
12
- author="Smilinno",
13
- packages=[
14
- "sharedkernel",
15
- "sharedkernel.database",
16
- "sharedkernel.database.vector_database_repository",
17
- "sharedkernel.enum",
18
- "sharedkernel.exception",
19
- "sharedkernel.objects",
20
- ],
21
- # Needed for dependencies
22
- install_requires=[
23
- "numpy",
24
- "requests",
25
- "pymongo",
26
- "fastapi==0.89.1",
27
- "PyJWT",
28
- "pymilvus",
29
- "chromadb",
30
- "persian_tools",
31
- "sentry-sdk",
32
- "jdatetime"
33
- ],
34
- # *strongly* suggested for sharing
35
- version="1.4.4",
36
- description="sharekernel is an shared package between all python projects",
37
- long_description=long_description,
38
- long_description_content_type="text/markdown",
39
- )
1
from setuptools import setup

# read the contents of your README file
from pathlib import Path

this_directory = Path(__file__).parent
# Decode the README explicitly as UTF-8. Without an encoding argument,
# read_text() uses the locale's preferred encoding (e.g. cp1252 on Windows,
# where the README's build instructions run `py -m build`), so any non-ASCII
# character in README.md would fail to decode or be garbled.
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

setup(
    # Needed to silence warnings (and to be a worthwhile package)
    name="sharedkernel",
    author="Smilinno",
    packages=[
        "sharedkernel",
        "sharedkernel.database",
        "sharedkernel.database.vector_database_repository",
        "sharedkernel.enum",
        "sharedkernel.exception",
        "sharedkernel.objects",
    ],
    # Needed for dependencies
    install_requires=[
        "numpy",
        "requests",
        "pymongo",
        "fastapi==0.111.0",
        "PyJWT",
        "pymilvus",
        "chromadb",
        "persian_tools",
        "sentry-sdk",
        "jdatetime",
        "persiantools",
        # sharedkernel/common.py does `import yaml`; declare it directly
        # instead of relying on another package pulling it in transitively.
        "PyYAML",
    ],
    # *strongly* suggested for sharing
    version="1.5.0",
    description="sharekernel is an shared package between all python projects",
    long_description=long_description,
    long_description_content_type="text/markdown",
)
@@ -1,29 +1,29 @@
1
- import yaml
2
- import json
3
-
4
- def yaml2json(yaml_data:yaml) -> json:
5
-
6
- output = json.dumps(yaml.safe_load(yaml_data), indent=2)
7
-
8
- return json.loads(output)
9
-
10
- def json2yaml(data) -> yaml:
11
-
12
- data=json.dumps(data,default=lambda o: del_none(o.__dict__))
13
-
14
- data=json.loads(data)
15
- output = yaml.dump(data)
16
-
17
- return output
18
-
19
- def del_none(d):
20
- """
21
- Delete keys with the value ``None`` in a dictionary, recursively.
22
- """
23
-
24
- for key, value in list(d.items()):
25
- if value is None:
26
- del d[key]
27
- elif isinstance(value, dict):
28
- del_none(value)
29
- return d
1
+ import yaml
2
+ import json
3
+
4
def yaml2json(yaml_data: str) -> object:
    """Parse a YAML document and return it as JSON-compatible Python data.

    The parsed value is round-tripped through ``json``, so the result holds
    only JSON types (for example, non-string mapping keys come back as
    strings). Values JSON cannot encode, such as YAML timestamps, make
    ``json.dumps`` raise ``TypeError``.

    :param yaml_data: YAML document text.
    :return: the ``dict`` / ``list`` / scalar rebuilt by ``json.loads``.
    """
    # safe_load (not load) so the YAML cannot instantiate arbitrary Python objects.
    parsed = yaml.safe_load(yaml_data)
    # The intermediate JSON string is discarded immediately, so it is not indented.
    return json.loads(json.dumps(parsed))
9
+
10
def json2yaml(data) -> str:
    """Serialize *data* to a YAML string.

    Values ``json`` cannot encode natively are serialized from their
    ``__dict__``, leaving out ``None``-valued entries (also inside nested
    dicts). The source objects are not modified.

    :param data: JSON-compatible data, possibly containing plain objects.
    :return: the YAML text produced by ``yaml.dump``.
    :raises AttributeError: if a non-JSON value has no ``__dict__``.
    """

    def without_none(mapping):
        # Build a pruned copy rather than deleting keys in place, so that
        # serializing never strips attributes off the caller's objects.
        return {
            key: without_none(value) if isinstance(value, dict) else value
            for key, value in mapping.items()
            if value is not None
        }

    # Round-trip through json so yaml.dump only ever sees plain JSON types.
    text = json.dumps(data, default=lambda o: without_none(o.__dict__))
    return yaml.dump(json.loads(text))
18
+
19
def del_none(d):
    """
    Remove every ``None``-valued entry from the dict *d* in place, descending
    into values that are themselves dicts (lists are not searched).

    Returns *d* itself so the call can be chained.
    """
    doomed = [k for k, v in d.items() if v is None]
    for k in doomed:
        del d[k]
    for child in d.values():
        if isinstance(child, dict):
            del_none(child)
    return d
@@ -1,2 +1,2 @@
1
- from .vector_database_repository import VectorRepository
1
+ from .vector_database_repository import VectorRepository
2
2
  # from .mongo_repository_base import MongoRepositoryBase
@@ -1,50 +1,51 @@
1
- from pymongo import MongoClient
2
- from bson import ObjectId
3
- from typing import Generic, TypeVar, List, Type
4
- from pydantic import BaseModel
5
- from sharedkernel.string_extentions import camel_to_snake
6
-
7
- T = TypeVar("T", bound=BaseModel)
8
-
9
-
10
- class MongoGenericRepository(Generic[T]):
11
- def __init__(self, database: MongoClient, model: Type[T]):
12
- self.database = database
13
- self.__collection_name = camel_to_snake(model.__name__)
14
- self.collection = self.database[self.__collection_name]
15
- self.model = model
16
-
17
- def _map_to_model(self, document: dict) -> T:
18
- document["id"] = str(document.pop("_id"))
19
- return self.model.parse_obj(document)
20
-
21
- def find_one(self, id: str) -> T:
22
- query = {"_id": ObjectId(id), "is_deleted": False}
23
- result = self.collection.find_one(query)
24
- return self._map_to_model(result) if result else None
25
-
26
- def insert_one(self, data: T) -> str:
27
- delattr(data, "id")
28
- result = self.collection.insert_one(data.dict())
29
- return str(result.inserted_id)
30
-
31
- def insert_many(self, data: List[T]) -> List[str]:
32
- data_list = [delattr(d.dict(), "id") for d in data]
33
- result = self.collection.insert_many(data_list)
34
- return [str(id_) for id_ in result.inserted_ids]
35
-
36
- def update_one(self, id: str, data: T) -> int:
37
- delattr(data, "id")
38
- query = {"_id": ObjectId(id)}
39
- result = self.collection.update_one(query, {"$set": data.dict()})
40
- return result.modified_count
41
-
42
- def delete_one(self, id: str) -> int:
43
- query = {"_id": ObjectId(id)}
44
- result = self.collection.delete_one(query)
45
- return result.deleted_count
46
-
47
- def get_all(self, page_number=1, page_size=10) -> List[T]:
48
- skip_count = (page_number - 1) * page_size
49
- result = self.collection.find().skip(skip_count).limit(page_size)
50
- return [self._map_to_model(doc) for doc in result]
1
+ from pymongo import MongoClient
2
+ from bson import ObjectId
3
+ from typing import Generic, TypeVar, List, Type
4
+ from pydantic import BaseModel
5
+ from sharedkernel.string_extentions import camel_to_snake
6
+
7
T = TypeVar("T", bound=BaseModel)


class MongoGenericRepository(Generic[T]):
    """Generic CRUD repository that stores pydantic models in MongoDB.

    Documents live in the collection named after the snake_case form of the
    model class name. ``find_one`` and ``get_all`` only return documents whose
    ``is_deleted`` field equals ``False``; ``delete_one`` removes the document
    physically.
    """

    def __init__(self, database: MongoClient, model: Type[T]):
        """
        :param database: object indexed by collection name (``database[name]``).
            NOTE(review): annotated ``MongoClient``, but indexing a MongoClient
            yields a Database rather than a Collection, so callers presumably
            pass a pymongo ``Database`` -- confirm.
        :param model: pydantic model class used to build query results.
        """
        self.database = database
        self.__collection_name = camel_to_snake(model.__name__)
        self.collection = self.database[self.__collection_name]
        self.model = model

    def _map_to_model(self, document: dict) -> T:
        """Build a model from a raw document, exposing ``_id`` as the string field ``id``."""
        document["id"] = str(document.pop("_id"))
        return self.model.parse_obj(document)

    @staticmethod
    def _to_document(data: T) -> dict:
        """Serialize *data* for storage without its ``id`` field.

        ``id`` is left out so MongoDB's own ``_id`` stays authoritative. Using
        ``exclude`` instead of ``delattr`` leaves the caller's model untouched.
        """
        return data.dict(exclude={"id"})

    def find_one(self, id: str) -> T:
        """Return the non-deleted document whose ``_id`` is *id*, or ``None`` if absent."""
        query = {"_id": ObjectId(id), "is_deleted": False}
        result = self.collection.find_one(query)
        return self._map_to_model(result) if result else None

    def insert_one(self, data: T) -> str:
        """Insert *data* and return the new document's id as a string."""
        result = self.collection.insert_one(self._to_document(data))
        return str(result.inserted_id)

    def insert_many(self, data: List[T]) -> List[str]:
        """Insert every model in *data* and return the new ids as strings."""
        if not data:
            # pymongo refuses an empty batch; inserting nothing yields no ids.
            return []
        # Serialize each model to a dict minus ``id`` (deleting an attribute
        # from the dict returned by .dict() is not possible).
        data_list = [self._to_document(d) for d in data]
        result = self.collection.insert_many(data_list)
        return [str(id_) for id_ in result.inserted_ids]

    def update_one(self, id: str, data: T) -> int:
        """``$set`` every field of *data* except ``id`` on document *id*; return the modified count."""
        query = {"_id": ObjectId(id)}
        result = self.collection.update_one(query, {"$set": self._to_document(data)})
        return result.modified_count

    def delete_one(self, id: str) -> int:
        """Physically delete document *id*; return the number of deleted documents."""
        query = {"_id": ObjectId(id)}
        result = self.collection.delete_one(query)
        return result.deleted_count

    def get_all(self, page_number=1, page_size=10) -> List[T]:
        """Return one page (``page_number`` is 1-based) of non-deleted documents."""
        skip_count = (page_number - 1) * page_size
        # Only the soft-delete filter applies here, as in find_one: a page spans
        # many documents, so there is no single ``_id`` to match.
        query = {"is_deleted": False}
        result = self.collection.find(query).skip(skip_count).limit(page_size)
        return [self._map_to_model(doc) for doc in result]
@@ -1,3 +1,3 @@
1
- from .chroma_startegy import ChromaStrategy
2
- from .milvus_strategy import MilvusStrategy
1
+ from .chroma_startegy import ChromaStrategy
2
+ from .milvus_strategy import MilvusStrategy
3
3
  from .vector_database_repository import VectorRepository
@@ -1,39 +1,39 @@
1
- import chromadb
2
- import numpy as np
3
- from chromadb.config import Settings
4
- from .vector_database_strategy import VectorDatabaseStrategy
5
- import uuid
6
-
7
-
8
- class ChromaStrategy(VectorDatabaseStrategy):
9
- def __init__(self, collection_name: str):
10
- self.collection_name = collection_name
11
- self.collection = None
12
-
13
- def connect(self, host: str = "localhost", port: int = 8000):
14
- client = chromadb.Client(
15
- Settings(
16
- chroma_api_impl="rest",
17
- chroma_server_host=host,
18
- chroma_server_http_port=port,
19
- )
20
- )
21
- self.collection = client.get_or_create_collection(self.collection_name)
22
-
23
- def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
24
- id = str(uuid.uuid4())
25
- self.collection.upsert(ids=id, embeddings=vector.tolist(), metadatas=[metadata])
26
-
27
- return id
28
-
29
- def search_vector(self, vector: np.ndarray, top_k: int):
30
- results = self.collection.query(vectors=[vector.tolist()], n_results=top_k)
31
- return results
32
-
33
- def get_vector_by_id(self, id: str):
34
- result = self.collection.get(ids=id)
35
-
36
- return result
37
-
38
- def delete_vector(self, id: str):
39
- self.collection.delete(ids=id)
1
+ import chromadb
2
+ import numpy as np
3
+ from chromadb.config import Settings
4
+ from .vector_database_strategy import VectorDatabaseStrategy
5
+ import uuid
6
+
7
+
8
class ChromaStrategy(VectorDatabaseStrategy):
    """Vector storage backed by a collection on a Chroma server."""

    def __init__(self, collection_name: str):
        self.collection_name = collection_name
        self.collection = None  # assigned by connect()

    def connect(self, host: str = "localhost", port: int = 8000):
        """Connect to the Chroma server at *host*:*port* and open (or create) the collection."""
        http_client = getattr(chromadb, "HttpClient", None)
        if http_client is not None:
            # chromadb >= 0.4 provides HttpClient and rejects the legacy
            # ``chroma_api_impl="rest"`` setting used in the fallback below.
            client = http_client(host=host, port=port)
        else:
            # Older chromadb releases: configure the REST client through Settings.
            client = chromadb.Client(
                Settings(
                    chroma_api_impl="rest",
                    chroma_server_host=host,
                    chroma_server_http_port=port,
                )
            )
        self.collection = client.get_or_create_collection(self.collection_name)

    def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
        """Upsert *vector* with *metadata* under a fresh UUID4 string and return that id."""
        vector_id = str(uuid.uuid4())
        self.collection.upsert(ids=vector_id, embeddings=vector.tolist(), metadatas=[metadata])
        return vector_id

    def search_vector(self, vector: np.ndarray, top_k: int):
        """Return Chroma's query result for the *top_k* nearest neighbours of *vector*."""
        # Collection.query takes embeddings as ``query_embeddings``; it has no
        # ``vectors`` keyword.
        return self.collection.query(query_embeddings=[vector.tolist()], n_results=top_k)

    def get_vector_by_id(self, id: str):
        """Return Chroma's ``get`` result for *id*."""
        return self.collection.get(ids=id)

    def delete_vector(self, id: str):
        """Delete the entry stored under *id*."""
        self.collection.delete(ids=id)
@@ -1,50 +1,50 @@
1
- import numpy as np
2
- from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
3
-
4
- from .vector_database_strategy import VectorDatabaseStrategy
5
- import uuid
6
-
7
-
8
- class MilvusStrategy(VectorDatabaseStrategy):
9
- def __init__(self, collection_name: str):
10
- self.collection_name = collection_name
11
- self.collection = None
12
-
13
- def connect(self, host: str = "localhost", port: str = "19530"):
14
- connections.connect(alias="default", host=host, port=port)
15
- # Define fields
16
- fields = [
17
- FieldSchema(
18
- name="id", dtype=DataType.VARCHAR, max_length=36, is_primary=True
19
- ),
20
- FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128),
21
- ]
22
- schema = CollectionSchema(
23
- fields, description="Vector collection", enable_dynamic_field=True
24
- )
25
- self.collection = Collection(name=self.collection_name, schema=schema)
26
-
27
- if not self.collection.has_index():
28
- self.collection.create_index(field_name="vector", index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}})
29
-
30
- self.collection.load()
31
-
32
- def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
33
- id = str(uuid.uuid4())
34
- self.collection.insert(data={"id": id, "vector": vector.tolist()})
35
-
36
- return id
37
-
38
- def search_vector(self, vector: np.ndarray, top_k: int):
39
- search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
40
- results = self.collection.search(
41
- [vector.tolist()], "vector", search_params, top_k
42
- )
43
- return results
44
-
45
- def get_vector_by_id(self, id: str):
46
- result = self.collection.query(expr=f"id=='{id}'",output_fields=["vector"])
47
- return result
48
-
49
- def delete_vector(self, id: str):
50
- self.collection.delete(expr=f"id=='{id}'")
1
+ import numpy as np
2
+ from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
3
+
4
+ from .vector_database_strategy import VectorDatabaseStrategy
5
+ import uuid
6
+
7
+
8
class MilvusStrategy(VectorDatabaseStrategy):
    """Vector storage backed by a Milvus collection keyed by UUID strings."""

    def __init__(self, collection_name: str):
        self.collection_name = collection_name
        self.collection = None  # assigned by connect()

    def connect(self, host: str = "localhost", port: str = "19530", dim: int = 128):
        """Connect to Milvus, open/create the collection, ensure an index and load it.

        :param host: Milvus server host.
        :param port: Milvus server port.
        :param dim: dimensionality of the ``vector`` field (defaults to 128).
        """
        connections.connect(alias="default", host=host, port=port)
        fields = [
            FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=36, is_primary=True),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        # Dynamic fields let each row carry arbitrary metadata keys beyond id/vector.
        schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
        self.collection = Collection(name=self.collection_name, schema=schema)

        if not self.collection.has_index():
            self.collection.create_index(
                field_name="vector",
                index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}},
            )

        # Search/query require the collection to be loaded.
        self.collection.load()

    @staticmethod
    def _id_expr(vector_id: str) -> str:
        """Build a filter expression matching primary key *vector_id*.

        Backslashes and single quotes are escaped so an id cannot close the
        string literal and inject extra expression syntax.
        """
        escaped = str(vector_id).replace("\\", "\\\\").replace("'", "\\'")
        return f"id == '{escaped}'"

    def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
        """Insert *vector* under a fresh UUID4 id, storing *metadata* keys as dynamic fields.

        :return: the generated id.
        """
        vector_id = str(uuid.uuid4())
        # Metadata is spread first so it can never overwrite the primary key or the vector.
        row = {**(metadata or {}), "id": vector_id, "vector": vector.tolist()}
        self.collection.insert(data=row)
        return vector_id

    def search_vector(self, vector: np.ndarray, top_k: int):
        """Return Milvus's L2 search result for the *top_k* nearest neighbours of *vector*."""
        search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
        return self.collection.search([vector.tolist()], "vector", search_params, top_k)

    def get_vector_by_id(self, id: str):
        """Return the query result (``vector`` field) for primary key *id*."""
        return self.collection.query(expr=self._id_expr(id), output_fields=["vector"])

    def delete_vector(self, id: str):
        """Delete the row whose primary key is *id*."""
        self.collection.delete(expr=self._id_expr(id))
@@ -1,28 +1,28 @@
1
- import numpy as np
2
- from .vector_database_strategy import VectorDatabaseStrategy
3
- from .milvus_strategy import MilvusStrategy
4
- from .chroma_startegy import ChromaStrategy
5
- from sharedkernel.enum.vector_database_type import VectorDatabaseType
6
-
7
- class VectorRepository:
8
- def __init__(self, database_type: VectorDatabaseType, collection_name: str, **connection_params):
9
- self.strategy = self._get_strategy(database_type, collection_name)
10
- self.strategy.connect(**connection_params)
11
-
12
- def _get_strategy(self, database_type: VectorDatabaseType, collection_name: str) -> VectorDatabaseStrategy:
13
- if database_type == VectorDatabaseType.MILVUS:
14
- return MilvusStrategy(collection_name)
15
- else:
16
- return ChromaStrategy(collection_name)
17
-
18
- def add_vector(self, vector: np.ndarray, metadata: dict) -> str:
19
- return self.strategy.insert_vector(vector, metadata)
20
-
21
- def find_similar_vectors(self, vector: np.ndarray, top_k: int):
22
- return self.strategy.search_vector(vector, top_k)
23
-
24
- def remove_vector(self, id: str):
25
- self.strategy.delete_vector(id)
26
-
27
- def get_vector_by_id(self, id: str):
1
+ import numpy as np
2
+ from .vector_database_strategy import VectorDatabaseStrategy
3
+ from .milvus_strategy import MilvusStrategy
4
+ from .chroma_startegy import ChromaStrategy
5
+ from sharedkernel.enum.vector_database_type import VectorDatabaseType
6
+
7
class VectorRepository:
    """Backend-agnostic entry point for vector storage.

    A concrete ``VectorDatabaseStrategy`` is picked from ``database_type`` and
    connected immediately; every public method forwards to it.
    """

    def __init__(self, database_type: VectorDatabaseType, collection_name: str, **connection_params):
        # connection_params are passed through untouched to the backend's connect().
        self.strategy = self._get_strategy(database_type, collection_name)
        self.strategy.connect(**connection_params)

    def _get_strategy(self, database_type: VectorDatabaseType, collection_name: str) -> VectorDatabaseStrategy:
        """Instantiate Milvus for ``VectorDatabaseType.MILVUS`` and Chroma for any other value."""
        is_milvus = database_type == VectorDatabaseType.MILVUS
        backend_cls = MilvusStrategy if is_milvus else ChromaStrategy
        return backend_cls(collection_name)

    def add_vector(self, vector: np.ndarray, metadata: dict) -> str:
        """Store *vector* with *metadata*; return the id assigned by the backend."""
        new_id = self.strategy.insert_vector(vector, metadata)
        return new_id

    def find_similar_vectors(self, vector: np.ndarray, top_k: int):
        """Return the backend's raw nearest-neighbour result for *vector*."""
        hits = self.strategy.search_vector(vector, top_k)
        return hits

    def remove_vector(self, id: str):
        """Delete the vector stored under *id*."""
        self.strategy.delete_vector(id)

    def get_vector_by_id(self, id: str):
        """Return the backend's record for *id*."""
        record = self.strategy.get_vector_by_id(id)
        return record
@@ -1,22 +1,22 @@
1
- from abc import ABC, abstractmethod
2
- import numpy as np
3
-
4
- class VectorDatabaseStrategy(ABC):
5
- @abstractmethod
6
- def connect(self, **kwargs):
7
- pass
8
- @abstractmethod
9
- def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
10
- pass
11
-
12
- @abstractmethod
13
- def search_vector(self, vector: np.ndarray, top_k: int):
14
- pass
15
-
16
- @abstractmethod
17
- def delete_vector(self, id: str):
18
- pass
19
-
20
- @abstractmethod
21
- def get_vector_by_id(self, id: str):
22
- pass
1
+ from abc import ABC, abstractmethod
2
+ import numpy as np
3
+
4
class VectorDatabaseStrategy(ABC):
    """Interface that every vector-database backend implements."""

    @abstractmethod
    def connect(self, **kwargs):
        """Open the backend connection; the accepted keyword arguments are backend-specific."""

    @abstractmethod
    def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
        """Store *vector* together with *metadata* and return its identifier."""

    @abstractmethod
    def search_vector(self, vector: np.ndarray, top_k: int):
        """Return the backend's result for the *top_k* vectors closest to *vector*."""

    @abstractmethod
    def delete_vector(self, id: str):
        """Remove the vector stored under *id*."""

    @abstractmethod
    def get_vector_by_id(self, id: str):
        """Return the backend's record for *id*."""