sharedkernel 1.5.2__tar.gz → 1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {sharedkernel-1.5.2 → sharedkernel-1.6}/PKG-INFO +55 -53
  2. {sharedkernel-1.5.2 → sharedkernel-1.6}/README.md +3 -1
  3. {sharedkernel-1.5.2 → sharedkernel-1.6}/setup.cfg +4 -4
  4. {sharedkernel-1.5.2 → sharedkernel-1.6}/setup.py +40 -40
  5. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/common.py +29 -29
  6. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/database/__init__.py +1 -1
  7. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/database/mongo_generic_repository.py +52 -52
  8. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/database/vector_database_repository/__init__.py +2 -2
  9. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/database/vector_database_repository/chroma_startegy.py +39 -39
  10. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/database/vector_database_repository/milvus_strategy.py +50 -50
  11. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/database/vector_database_repository/vector_database_repository.py +27 -27
  12. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/database/vector_database_repository/vector_database_strategy.py +22 -22
  13. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/date_converter.py +122 -122
  14. sharedkernel-1.6/sharedkernel/enum/__init__.py +2 -0
  15. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/enum/error_code.py +14 -14
  16. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/enum/vector_database_type.py +5 -5
  17. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/exception/__init__.py +4 -4
  18. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/exception/exception.py +39 -39
  19. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/exception/exception_handlers.py +75 -75
  20. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/jwt_service.py +38 -38
  21. sharedkernel-1.6/sharedkernel/normalizer.py +770 -0
  22. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/objects/__init__.py +2 -2
  23. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/objects/base_document.py +9 -9
  24. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/objects/jwt_model.py +6 -6
  25. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/objects/result.py +28 -28
  26. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/regex_masking.py +92 -92
  27. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel/string_extentions.py +4 -4
  28. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel.egg-info/PKG-INFO +55 -53
  29. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel.egg-info/SOURCES.txt +1 -0
  30. sharedkernel-1.5.2/sharedkernel/enum/__init__.py +0 -2
  31. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel.egg-info/dependency_links.txt +0 -0
  32. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel.egg-info/requires.txt +0 -0
  33. {sharedkernel-1.5.2 → sharedkernel-1.6}/sharedkernel.egg-info/top_level.txt +0 -0
@@ -1,53 +1,55 @@
1
- Metadata-Version: 2.1
2
- Name: sharedkernel
3
- Version: 1.5.2
4
- Summary: sharekernel is an shared package between all python projects
5
- Author: Smilinno
6
- Description-Content-Type: text/markdown
7
- Requires-Dist: numpy
8
- Requires-Dist: requests
9
- Requires-Dist: pymongo
10
- Requires-Dist: fastapi==0.111.0
11
- Requires-Dist: PyJWT
12
- Requires-Dist: pymilvus
13
- Requires-Dist: chromadb
14
- Requires-Dist: persian_tools
15
- Requires-Dist: sentry-sdk
16
- Requires-Dist: jdatetime
17
- Requires-Dist: persiantools
18
-
19
- # SharedKernel
20
- this a shared kernel package
21
-
22
- # Change Log
23
- ### Version 1.5.2
24
- - add created_on to base_document
25
- ### Version 1.5.1
26
- - fix mongo repository bug
27
- ### Version 1.5.0
28
- - implement date converter
29
- - example: فردا - دیروز - یک ماه قبل
30
- ### Version 1.4.5
31
- - upgrade fastapi version
32
- ### Version 1.4.4
33
- - Fix regex masking bugs
34
- ### Version 1.4.3
35
- - Fix collection bug in MongoGenericRepository
36
- ### Version 1.4.2
37
- - Fix minor bugs
38
- ### Version 1.4.1
39
- - Fix minor bug in MongoGenericRepository
40
- ### Version 1.4.0
41
- - Implement date convertor for jalali and georgian
42
- ### Version 1.3.0
43
- - Implement Sentry For Log Exceptions
44
- ### Version 1.2.0
45
- - Implement Regex Masking
46
- # Create Package
47
- py -m pip install --upgrade build
48
- py -m build
49
- py -m pip install --upgrade twine
50
- py -m twine upload dist/*
51
-
52
- # Pypi
53
- pip install sharedkernel
1
+ Metadata-Version: 2.1
2
+ Name: sharedkernel
3
+ Version: 1.6
4
+ Summary: sharekernel is an shared package between all python projects
5
+ Author: Smilinno
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: numpy
8
+ Requires-Dist: requests
9
+ Requires-Dist: pymongo
10
+ Requires-Dist: fastapi==0.111.0
11
+ Requires-Dist: PyJWT
12
+ Requires-Dist: pymilvus
13
+ Requires-Dist: chromadb
14
+ Requires-Dist: persian_tools
15
+ Requires-Dist: sentry-sdk
16
+ Requires-Dist: jdatetime
17
+ Requires-Dist: persiantools
18
+
19
+ # SharedKernel
20
+ this a shared kernel package
21
+
22
+ # Change Log
23
+ ### Version 1.6
24
+ - Add phone number normalizer
25
+ ### Version 1.5.2
26
+ - Add created_on to base_document
27
+ ### Version 1.5.1
28
+ - fix mongo repository bug
29
+ ### Version 1.5.0
30
+ - implement date converter
31
+ - example: فردا - دیروز - یک ماه قبل
32
+ ### Version 1.4.5
33
+ - upgrade fastapi version
34
+ ### Version 1.4.4
35
+ - Fix regex masking bugs
36
+ ### Version 1.4.3
37
+ - Fix collection bug in MongoGenericRepository
38
+ ### Version 1.4.2
39
+ - Fix minor bugs
40
+ ### Version 1.4.1
41
+ - Fix minor bug in MongoGenericRepository
42
+ ### Version 1.4.0
43
+ - Implement date convertor for jalali and georgian
44
+ ### Version 1.3.0
45
+ - Implement Sentry For Log Exceptions
46
+ ### Version 1.2.0
47
+ - Implement Regex Masking
48
+ # Create Package
49
+ py -m pip install --upgrade build
50
+ py -m build
51
+ py -m pip install --upgrade twine
52
+ py -m twine upload dist/*
53
+
54
+ # Pypi
55
+ pip install sharedkernel
@@ -2,8 +2,10 @@
2
2
  this a shared kernel package
3
3
 
4
4
  # Change Log
5
+ ### Version 1.6
6
+ - Add phone number normalizer
5
7
  ### Version 1.5.2
6
- - add created_on to base_document
8
+ - Add created_on to base_document
7
9
  ### Version 1.5.1
8
10
  - fix mongo repository bug
9
11
  ### Version 1.5.0
@@ -1,4 +1,4 @@
1
- [egg_info]
2
- tag_build =
3
- tag_date = 0
4
-
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -1,40 +1,40 @@
1
- from setuptools import setup
2
-
3
- # read the contents of your README file
4
- from pathlib import Path
5
-
6
- this_directory = Path(__file__).parent
7
- long_description = (this_directory / "README.md").read_text()
8
-
9
- setup(
10
- # Needed to silence warnings (and to be a worthwhile package)
11
- name="sharedkernel",
12
- author="Smilinno",
13
- packages=[
14
- "sharedkernel",
15
- "sharedkernel.database",
16
- "sharedkernel.database.vector_database_repository",
17
- "sharedkernel.enum",
18
- "sharedkernel.exception",
19
- "sharedkernel.objects",
20
- ],
21
- # Needed for dependencies
22
- install_requires=[
23
- "numpy",
24
- "requests",
25
- "pymongo",
26
- "fastapi==0.111.0",
27
- "PyJWT",
28
- "pymilvus",
29
- "chromadb",
30
- "persian_tools",
31
- "sentry-sdk",
32
- "jdatetime",
33
- "persiantools"
34
- ],
35
- # *strongly* suggested for sharing
36
- version="1.5.2",
37
- description="sharekernel is an shared package between all python projects",
38
- long_description=long_description,
39
- long_description_content_type="text/markdown",
40
- )
1
+ from setuptools import setup
2
+
3
+ # read the contents of your README file
4
+ from pathlib import Path
5
+
6
+ this_directory = Path(__file__).parent
7
+ long_description = (this_directory / "README.md").read_text()
8
+
9
+ setup(
10
+ # Needed to silence warnings (and to be a worthwhile package)
11
+ name="sharedkernel",
12
+ author="Smilinno",
13
+ packages=[
14
+ "sharedkernel",
15
+ "sharedkernel.database",
16
+ "sharedkernel.database.vector_database_repository",
17
+ "sharedkernel.enum",
18
+ "sharedkernel.exception",
19
+ "sharedkernel.objects",
20
+ ],
21
+ # Needed for dependencies
22
+ install_requires=[
23
+ "numpy",
24
+ "requests",
25
+ "pymongo",
26
+ "fastapi==0.111.0",
27
+ "PyJWT",
28
+ "pymilvus",
29
+ "chromadb",
30
+ "persian_tools",
31
+ "sentry-sdk",
32
+ "jdatetime",
33
+ "persiantools"
34
+ ],
35
+ # *strongly* suggested for sharing
36
+ version="1.6",
37
+ description="sharekernel is an shared package between all python projects",
38
+ long_description=long_description,
39
+ long_description_content_type="text/markdown",
40
+ )
@@ -1,29 +1,29 @@
1
- import yaml
2
- import json
3
-
4
- def yaml2json(yaml_data:yaml) -> json:
5
-
6
- output = json.dumps(yaml.safe_load(yaml_data), indent=2)
7
-
8
- return json.loads(output)
9
-
10
- def json2yaml(data) -> yaml:
11
-
12
- data=json.dumps(data,default=lambda o: del_none(o.__dict__))
13
-
14
- data=json.loads(data)
15
- output = yaml.dump(data)
16
-
17
- return output
18
-
19
- def del_none(d):
20
- """
21
- Delete keys with the value ``None`` in a dictionary, recursively.
22
- """
23
-
24
- for key, value in list(d.items()):
25
- if value is None:
26
- del d[key]
27
- elif isinstance(value, dict):
28
- del_none(value)
29
- return d
1
+ import yaml
2
+ import json
3
+
4
+ def yaml2json(yaml_data:yaml) -> json:
5
+
6
+ output = json.dumps(yaml.safe_load(yaml_data), indent=2)
7
+
8
+ return json.loads(output)
9
+
10
+ def json2yaml(data) -> yaml:
11
+
12
+ data=json.dumps(data,default=lambda o: del_none(o.__dict__))
13
+
14
+ data=json.loads(data)
15
+ output = yaml.dump(data)
16
+
17
+ return output
18
+
19
+ def del_none(d):
20
+ """
21
+ Delete keys with the value ``None`` in a dictionary, recursively.
22
+ """
23
+
24
+ for key, value in list(d.items()):
25
+ if value is None:
26
+ del d[key]
27
+ elif isinstance(value, dict):
28
+ del_none(value)
29
+ return d
@@ -1,2 +1,2 @@
1
- from .vector_database_repository import VectorRepository
1
+ from .vector_database_repository import VectorRepository
2
2
  # from .mongo_repository_base import MongoRepositoryBase
@@ -1,52 +1,52 @@
1
- from pymongo import MongoClient
2
- from bson import ObjectId
3
- from typing import Generic, TypeVar, List, Type
4
- from pydantic import BaseModel
5
- from sharedkernel.string_extentions import camel_to_snake
6
-
7
- T = TypeVar("T", bound=BaseModel)
8
-
9
-
10
- class MongoGenericRepository(Generic[T]):
11
- def __init__(self, database: MongoClient, model: Type[T]):
12
- self.database = database
13
- self.__collection_name = camel_to_snake(model.__name__)
14
- self.collection = self.database[self.__collection_name]
15
- self.model = model
16
-
17
- def _map_to_model(self, document: dict) -> T:
18
- document["id"] = str(document.pop("_id"))
19
- return self.model.parse_obj(document)
20
-
21
- def find_one(self, id: str) -> T:
22
- query = {"_id": ObjectId(id), "is_deleted": False}
23
- result = self.collection.find_one(query)
24
- return self._map_to_model(result) if result else None
25
-
26
-
27
- def insert_one(self, data: T) -> str:
28
- delattr(data, "id")
29
- result = self.collection.insert_one(data.dict())
30
- return str(result.inserted_id)
31
-
32
- def insert_many(self, data: List[T]) -> List[str]:
33
- data_list = [delattr(d.dict(), "id") for d in data]
34
- result = self.collection.insert_many(data_list)
35
- return [str(id_) for id_ in result.inserted_ids]
36
-
37
- def update_one(self, id: str, data: T) -> int:
38
- delattr(data, "id")
39
- query = {"_id": ObjectId(id)}
40
- result = self.collection.update_one(query, {"$set": data.dict()})
41
- return result.modified_count
42
-
43
- def delete_one(self, id: str) -> int:
44
- query = {"_id": ObjectId(id)}
45
- result = self.collection.delete_one(query)
46
- return result.deleted_count
47
-
48
- def get_all(self, page_number=1, page_size=10) -> List[T]:
49
- skip_count = (page_number - 1) * page_size
50
- query = {"is_deleted": False}
51
- result = self.collection.find(query).skip(skip_count).limit(page_size)
52
- return [self._map_to_model(doc) for doc in result]
1
+ from pymongo import MongoClient
2
+ from bson import ObjectId
3
+ from typing import Generic, TypeVar, List, Type
4
+ from pydantic import BaseModel
5
+ from sharedkernel.string_extentions import camel_to_snake
6
+
7
+ T = TypeVar("T", bound=BaseModel)
8
+
9
+
10
+ class MongoGenericRepository(Generic[T]):
11
+ def __init__(self, database: MongoClient, model: Type[T]):
12
+ self.database = database
13
+ self.__collection_name = camel_to_snake(model.__name__)
14
+ self.collection = self.database[self.__collection_name]
15
+ self.model = model
16
+
17
+ def _map_to_model(self, document: dict) -> T:
18
+ document["id"] = str(document.pop("_id"))
19
+ return self.model.parse_obj(document)
20
+
21
+ def find_one(self, id: str) -> T:
22
+ query = {"_id": ObjectId(id), "is_deleted": False}
23
+ result = self.collection.find_one(query)
24
+ return self._map_to_model(result) if result else None
25
+
26
+
27
+ def insert_one(self, data: T) -> str:
28
+ delattr(data, "id")
29
+ result = self.collection.insert_one(data.dict())
30
+ return str(result.inserted_id)
31
+
32
+ def insert_many(self, data: List[T]) -> List[str]:
33
+ data_list = [delattr(d.dict(), "id") for d in data]
34
+ result = self.collection.insert_many(data_list)
35
+ return [str(id_) for id_ in result.inserted_ids]
36
+
37
+ def update_one(self, id: str, data: T) -> int:
38
+ delattr(data, "id")
39
+ query = {"_id": ObjectId(id)}
40
+ result = self.collection.update_one(query, {"$set": data.dict()})
41
+ return result.modified_count
42
+
43
+ def delete_one(self, id: str) -> int:
44
+ query = {"_id": ObjectId(id)}
45
+ result = self.collection.delete_one(query)
46
+ return result.deleted_count
47
+
48
+ def get_all(self, page_number=1, page_size=10) -> List[T]:
49
+ skip_count = (page_number - 1) * page_size
50
+ query = {"is_deleted": False}
51
+ result = self.collection.find(query).skip(skip_count).limit(page_size)
52
+ return [self._map_to_model(doc) for doc in result]
@@ -1,3 +1,3 @@
1
- from .chroma_startegy import ChromaStrategy
2
- from .milvus_strategy import MilvusStrategy
1
+ from .chroma_startegy import ChromaStrategy
2
+ from .milvus_strategy import MilvusStrategy
3
3
  from .vector_database_repository import VectorRepository
@@ -1,39 +1,39 @@
1
- import chromadb
2
- import numpy as np
3
- from chromadb.config import Settings
4
- from .vector_database_strategy import VectorDatabaseStrategy
5
- import uuid
6
-
7
-
8
- class ChromaStrategy(VectorDatabaseStrategy):
9
- def __init__(self, collection_name: str):
10
- self.collection_name = collection_name
11
- self.collection = None
12
-
13
- def connect(self, host: str = "localhost", port: int = 8000):
14
- client = chromadb.Client(
15
- Settings(
16
- chroma_api_impl="rest",
17
- chroma_server_host=host,
18
- chroma_server_http_port=port,
19
- )
20
- )
21
- self.collection = client.get_or_create_collection(self.collection_name)
22
-
23
- def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
24
- id = str(uuid.uuid4())
25
- self.collection.upsert(ids=id, embeddings=vector.tolist(), metadatas=[metadata])
26
-
27
- return id
28
-
29
- def search_vector(self, vector: np.ndarray, top_k: int):
30
- results = self.collection.query(vectors=[vector.tolist()], n_results=top_k)
31
- return results
32
-
33
- def get_vector_by_id(self, id: str):
34
- result = self.collection.get(ids=id)
35
-
36
- return result
37
-
38
- def delete_vector(self, id: str):
39
- self.collection.delete(ids=id)
1
+ import chromadb
2
+ import numpy as np
3
+ from chromadb.config import Settings
4
+ from .vector_database_strategy import VectorDatabaseStrategy
5
+ import uuid
6
+
7
+
8
+ class ChromaStrategy(VectorDatabaseStrategy):
9
+ def __init__(self, collection_name: str):
10
+ self.collection_name = collection_name
11
+ self.collection = None
12
+
13
+ def connect(self, host: str = "localhost", port: int = 8000):
14
+ client = chromadb.Client(
15
+ Settings(
16
+ chroma_api_impl="rest",
17
+ chroma_server_host=host,
18
+ chroma_server_http_port=port,
19
+ )
20
+ )
21
+ self.collection = client.get_or_create_collection(self.collection_name)
22
+
23
+ def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
24
+ id = str(uuid.uuid4())
25
+ self.collection.upsert(ids=id, embeddings=vector.tolist(), metadatas=[metadata])
26
+
27
+ return id
28
+
29
+ def search_vector(self, vector: np.ndarray, top_k: int):
30
+ results = self.collection.query(vectors=[vector.tolist()], n_results=top_k)
31
+ return results
32
+
33
+ def get_vector_by_id(self, id: str):
34
+ result = self.collection.get(ids=id)
35
+
36
+ return result
37
+
38
+ def delete_vector(self, id: str):
39
+ self.collection.delete(ids=id)
@@ -1,50 +1,50 @@
1
- import numpy as np
2
- from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
3
-
4
- from .vector_database_strategy import VectorDatabaseStrategy
5
- import uuid
6
-
7
-
8
- class MilvusStrategy(VectorDatabaseStrategy):
9
- def __init__(self, collection_name: str):
10
- self.collection_name = collection_name
11
- self.collection = None
12
-
13
- def connect(self, host: str = "localhost", port: str = "19530"):
14
- connections.connect(alias="default", host=host, port=port)
15
- # Define fields
16
- fields = [
17
- FieldSchema(
18
- name="id", dtype=DataType.VARCHAR, max_length=36, is_primary=True
19
- ),
20
- FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128),
21
- ]
22
- schema = CollectionSchema(
23
- fields, description="Vector collection", enable_dynamic_field=True
24
- )
25
- self.collection = Collection(name=self.collection_name, schema=schema)
26
-
27
- if not self.collection.has_index():
28
- self.collection.create_index(field_name="vector", index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}})
29
-
30
- self.collection.load()
31
-
32
- def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
33
- id = str(uuid.uuid4())
34
- self.collection.insert(data={"id": id, "vector": vector.tolist()})
35
-
36
- return id
37
-
38
- def search_vector(self, vector: np.ndarray, top_k: int):
39
- search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
40
- results = self.collection.search(
41
- [vector.tolist()], "vector", search_params, top_k
42
- )
43
- return results
44
-
45
- def get_vector_by_id(self, id: str):
46
- result = self.collection.query(expr=f"id=='{id}'",output_fields=["vector"])
47
- return result
48
-
49
- def delete_vector(self, id: str):
50
- self.collection.delete(expr=f"id=='{id}'")
1
+ import numpy as np
2
+ from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
3
+
4
+ from .vector_database_strategy import VectorDatabaseStrategy
5
+ import uuid
6
+
7
+
8
+ class MilvusStrategy(VectorDatabaseStrategy):
9
+ def __init__(self, collection_name: str):
10
+ self.collection_name = collection_name
11
+ self.collection = None
12
+
13
+ def connect(self, host: str = "localhost", port: str = "19530"):
14
+ connections.connect(alias="default", host=host, port=port)
15
+ # Define fields
16
+ fields = [
17
+ FieldSchema(
18
+ name="id", dtype=DataType.VARCHAR, max_length=36, is_primary=True
19
+ ),
20
+ FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128),
21
+ ]
22
+ schema = CollectionSchema(
23
+ fields, description="Vector collection", enable_dynamic_field=True
24
+ )
25
+ self.collection = Collection(name=self.collection_name, schema=schema)
26
+
27
+ if not self.collection.has_index():
28
+ self.collection.create_index(field_name="vector", index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}})
29
+
30
+ self.collection.load()
31
+
32
+ def insert_vector(self, vector: np.ndarray, metadata: dict) -> str:
33
+ id = str(uuid.uuid4())
34
+ self.collection.insert(data={"id": id, "vector": vector.tolist()})
35
+
36
+ return id
37
+
38
+ def search_vector(self, vector: np.ndarray, top_k: int):
39
+ search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
40
+ results = self.collection.search(
41
+ [vector.tolist()], "vector", search_params, top_k
42
+ )
43
+ return results
44
+
45
+ def get_vector_by_id(self, id: str):
46
+ result = self.collection.query(expr=f"id=='{id}'",output_fields=["vector"])
47
+ return result
48
+
49
+ def delete_vector(self, id: str):
50
+ self.collection.delete(expr=f"id=='{id}'")
@@ -1,28 +1,28 @@
1
- import numpy as np
2
- from .vector_database_strategy import VectorDatabaseStrategy
3
- from .milvus_strategy import MilvusStrategy
4
- from .chroma_startegy import ChromaStrategy
5
- from sharedkernel.enum.vector_database_type import VectorDatabaseType
6
-
7
- class VectorRepository:
8
- def __init__(self, database_type: VectorDatabaseType, collection_name: str, **connection_params):
9
- self.strategy = self._get_strategy(database_type, collection_name)
10
- self.strategy.connect(**connection_params)
11
-
12
- def _get_strategy(self, database_type: VectorDatabaseType, collection_name: str) -> VectorDatabaseStrategy:
13
- if database_type == VectorDatabaseType.MILVUS:
14
- return MilvusStrategy(collection_name)
15
- else:
16
- return ChromaStrategy(collection_name)
17
-
18
- def add_vector(self, vector: np.ndarray, metadata: dict) -> str:
19
- return self.strategy.insert_vector(vector, metadata)
20
-
21
- def find_similar_vectors(self, vector: np.ndarray, top_k: int):
22
- return self.strategy.search_vector(vector, top_k)
23
-
24
- def remove_vector(self, id: str):
25
- self.strategy.delete_vector(id)
26
-
27
- def get_vector_by_id(self, id: str):
1
+ import numpy as np
2
+ from .vector_database_strategy import VectorDatabaseStrategy
3
+ from .milvus_strategy import MilvusStrategy
4
+ from .chroma_startegy import ChromaStrategy
5
+ from sharedkernel.enum.vector_database_type import VectorDatabaseType
6
+
7
+ class VectorRepository:
8
+ def __init__(self, database_type: VectorDatabaseType, collection_name: str, **connection_params):
9
+ self.strategy = self._get_strategy(database_type, collection_name)
10
+ self.strategy.connect(**connection_params)
11
+
12
+ def _get_strategy(self, database_type: VectorDatabaseType, collection_name: str) -> VectorDatabaseStrategy:
13
+ if database_type == VectorDatabaseType.MILVUS:
14
+ return MilvusStrategy(collection_name)
15
+ else:
16
+ return ChromaStrategy(collection_name)
17
+
18
+ def add_vector(self, vector: np.ndarray, metadata: dict) -> str:
19
+ return self.strategy.insert_vector(vector, metadata)
20
+
21
+ def find_similar_vectors(self, vector: np.ndarray, top_k: int):
22
+ return self.strategy.search_vector(vector, top_k)
23
+
24
+ def remove_vector(self, id: str):
25
+ self.strategy.delete_vector(id)
26
+
27
+ def get_vector_by_id(self, id: str):
28
28
  return self.strategy.get_vector_by_id(id)