media-tagging 0.2.0.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. media-tagging-0.2.0.dev1/PKG-INFO +13 -0
  2. media-tagging-0.2.0.dev1/README.md +48 -0
  3. media-tagging-0.2.0.dev1/entrypoints/__init__.py +0 -0
  4. media-tagging-0.2.0.dev1/entrypoints/cli.py +93 -0
  5. media-tagging-0.2.0.dev1/entrypoints/server.py +101 -0
  6. media-tagging-0.2.0.dev1/media_tagging/__init__.py +0 -0
  7. media-tagging-0.2.0.dev1/media_tagging/llms.py +63 -0
  8. media-tagging-0.2.0.dev1/media_tagging/tagger.py +60 -0
  9. media-tagging-0.2.0.dev1/media_tagging/taggers/__init__.py +0 -0
  10. media-tagging-0.2.0.dev1/media_tagging/taggers/api.py +119 -0
  11. media-tagging-0.2.0.dev1/media_tagging/taggers/base.py +115 -0
  12. media-tagging-0.2.0.dev1/media_tagging/taggers/llm.py +192 -0
  13. media-tagging-0.2.0.dev1/media_tagging/tools.py +70 -0
  14. media-tagging-0.2.0.dev1/media_tagging/utils.py +20 -0
  15. media-tagging-0.2.0.dev1/media_tagging/writer.py +153 -0
  16. media-tagging-0.2.0.dev1/media_tagging.egg-info/PKG-INFO +13 -0
  17. media-tagging-0.2.0.dev1/media_tagging.egg-info/SOURCES.txt +28 -0
  18. media-tagging-0.2.0.dev1/media_tagging.egg-info/dependency_links.txt +1 -0
  19. media-tagging-0.2.0.dev1/media_tagging.egg-info/entry_points.txt +2 -0
  20. media-tagging-0.2.0.dev1/media_tagging.egg-info/requires.txt +11 -0
  21. media-tagging-0.2.0.dev1/media_tagging.egg-info/top_level.txt +3 -0
  22. media-tagging-0.2.0.dev1/setup.cfg +4 -0
  23. media-tagging-0.2.0.dev1/setup.py +54 -0
  24. media-tagging-0.2.0.dev1/tests/__init__.py +0 -0
  25. media-tagging-0.2.0.dev1/tests/conftest.py +34 -0
  26. media-tagging-0.2.0.dev1/tests/end_to_end/__init__.py +0 -0
  27. media-tagging-0.2.0.dev1/tests/end_to_end/test_main.py +48 -0
  28. media-tagging-0.2.0.dev1/tests/unit/__init__.py +0 -0
  29. media-tagging-0.2.0.dev1/tests/unit/test_tagger.py +84 -0
  30. media-tagging-0.2.0.dev1/tests/unit/test_writer.py +119 -0
@@ -0,0 +1,13 @@
1
+ Metadata-Version: 2.1
2
+ Name: media-tagging
3
+ Version: 0.2.0.dev1
4
+ Author: Google Inc. (gTech gPS CSE team)
5
+ Author-email: no-reply@google.com
6
+ License: Apache 2.0
7
+ Classifier: Programming Language :: Python :: 3 :: Only
8
+ Classifier: Programming Language :: Python :: 3.11
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: License :: OSI Approved :: Apache Software License
13
+ Description-Content-Type: text/markdown
@@ -0,0 +1,48 @@
1
+ # Welltech Media Tagging
2
+
3
+ ## Prerequisites
4
+
5
+ * Google Cloud project with billing enabled.
6
+ * [Video Intelligence API](https://console.cloud.google.com/apis/library/videointelligence.googleapis.com) and [Vision API](https://console.cloud.google.com/apis/library/vision.googleapis.com) enabled.
7
+ * Python 3.10+
8
+ * Access to repository configured. In order to clone this repository you need
9
+ to do the following:
10
+ * Visit https://professional-services.googlesource.com/new-password and
11
+ login with your account.
12
+ * Once authenticated, copy all lines in the box
13
+ and paste them in the terminal.
14
+
15
+
16
+ ## Run
17
+
18
+
19
+ 1. Install `media-tagging`
20
+
21
+ ```
22
+ pip install media-tagging
23
+ ```
24
+
25
+ 2. Perform tagging
26
+
27
+ ```
28
+ media-tagger --media-path MEDIA_PATH --tagger TAGGER_TYPE --writer WRITER_TYPE
29
+ ```
30
+ where:
31
+ * MEDIA_PATH - comma-separated paths of files to tag (local paths or URLs).
32
+ * TAGGER_TYPE - name of tagger, supported options:
33
+ * `vision-api` - tags images based on [Google Cloud Vision API](https://cloud.google.com/vision/),
34
+ * `video-api` - tags videos based on [Google Cloud Video Intelligence API](https://cloud.google.com/video-intelligence/),
35
+ * `gemini-image` - Uses Gemini to tag images. Add `--tagger.n_tags=<N_TAGS>`
36
+ parameter to control number of tags returned by tagger.
37
+ * `gemini-structured-image` - Uses Gemini to find certain tags in the images.
38
+ Add `--tagger.tags='tag1, tag2, ..., tagN'` parameter to find certain tags
39
+ in the image.
40
+ * `gemini-description-image` - Uses Gemini to provide a brief description of the image.
41
+ * WRITER_TYPE - name of writer, one of `csv`, `json`
42
+
43
+ By default the script creates a separate file with tagging results for each media path.
44
+ If you want to combine results into a single file, add the `--output OUTPUT_NAME` flag (without extension, e.g. `--output tagging_sample`).
45
+
46
+
47
+ ## Disclaimer
48
+ This is not an officially supported Google product.
File without changes
@@ -0,0 +1,93 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Provides CLI for media tagging."""
15
+
16
+ import argparse
17
+ import logging
18
+ import os
19
+
20
+ import smart_open
21
+ from gaarf.cli import utils as gaarf_utils
22
+
23
+ from media_tagging import tagger, utils, writer
24
+ from media_tagging.taggers import base as base_tagger
25
+
26
+
27
def tag_media(
    media_path: str | os.PathLike,
    tagger_type: base_tagger.BaseTagger,
    writer_type: writer.BaseWriter | None = None,
    single_output_name: str | None = None,
    tagging_parameters: dict[str, str] | None = None,
) -> None:
  """Runs media tagging algorithm.

  Args:
    media_path: Local or remote path to media file; several paths can be
      provided as one comma-separated string.
    tagger_type: Initialized tagger.
    writer_type: Initialized writer for saving tagging results; a new
      `writer.JsonWriter` is used when omitted.
    single_output_name: Parameter for saving results to a single file.
    tagging_parameters: Optional keywords arguments to be sent for tagging.
  """
  if writer_type is None:
    # Built per call instead of as a default argument, which would be
    # evaluated once at import time and shared between all calls.
    writer_type = writer.JsonWriter()
  if not tagging_parameters:
    tagging_parameters = {}
  # os.fspath accepts both str and os.PathLike, as the signature advertises.
  # Surrounding whitespace ("a.png, b.png") and empty entries (trailing
  # commas) would otherwise be passed to smart_open as bogus paths.
  media_paths = [
      path.strip()
      for path in os.fspath(media_path).split(',')
      if path.strip()
  ]
  results = []
  for path in media_paths:
    media_name = utils.convert_path_to_media_name(path)
    logging.info('Processing media: %s', path)
    with smart_open.open(path, 'rb') as f:
      media_bytes = f.read()
    results.append(
        tagger_type.tag(
            media_name,
            media_bytes,
            tagging_options=base_tagger.TaggingOptions(**tagging_parameters),
        )
    )
  writer_type.write(results, single_output_name)
60
+
61
+
62
def main():
  """Main entrypoint.

  Parses the known CLI flags and treats the remaining flags prefixed with
  `tagger` (e.g. `--tagger.n_tags=10`) as tagging parameters. Then tags every
  media path and writes the results.
  """
  parser = argparse.ArgumentParser()
  # Required: tag_media splits this value, so a missing flag used to crash
  # with an AttributeError on None.
  parser.add_argument('--media-path', dest='media_path', required=True)
  parser.add_argument('--tagger', dest='tagger', default='vision-api')
  parser.add_argument('--writer', dest='writer', default='json')
  # `--output` is the spelling documented in the README; register it
  # explicitly instead of relying on argparse prefix abbreviation.
  parser.add_argument(
      '--output-to-file', '--output', dest='output', default=None
  )
  parser.add_argument('--loglevel', dest='loglevel', default='INFO')
  args, kwargs = parser.parse_known_args()

  # Configure logging first so messages emitted while creating the tagger
  # and writer already use the requested format and level.
  logging.basicConfig(
      format='[%(asctime)s][%(name)s][%(levelname)s] %(message)s',
      level=args.loglevel.upper(),
      datefmt='%Y-%m-%d %H:%M:%S',
  )

  concrete_tagger = tagger.create_tagger(args.tagger)
  concrete_writer = writer.create_writer(args.writer)
  tagging_parameters = gaarf_utils.ParamsParser(['tagger']).parse(kwargs)

  tag_media(
      media_path=args.media_path,
      tagger_type=concrete_tagger,
      writer_type=concrete_writer,
      single_output_name=args.output,
      tagging_parameters=tagging_parameters.get('tagger'),
  )


if __name__ == '__main__':
  main()
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Provides HTTP endpoint for media tagging."""
15
+
16
+ import logging
17
+
18
+ import fastapi
19
+ import smart_open
20
+ from typing_extensions import TypedDict
21
+
22
+ from media_tagging import tagger, utils
23
+ from media_tagging.tagger import base as base_tagger
24
+
25
# FastAPI application exposing the tagging endpoints.
app = fastapi.FastAPI()
# Per-process cache of taggers keyed by tagger type; the endpoints populate
# it on first use and reuse the cached tagger for subsequent requests.
taggers: dict[str, base_tagger.BaseTagger] = dict()
27
+
28
+
29
class MediaPostRequest(TypedDict):
  """Specifies structure of request for tagging media.

  Attributes:
    media_url: Local or remote URL of media.
    tagging_parameters: Keyword arguments used to build
      `TaggingOptions` (e.g. `n_tags`, `tags`).
  """

  media_url: str
  tagging_parameters: dict[str, int | list[str]]
38
+
39
+
40
def _tag_media_url(
    tagger_type: str, data: MediaPostRequest
) -> fastapi.responses.JSONResponse:
  """Tags the media referenced in a request with a cached tagger.

  Args:
    tagger_type: Tagger type understood by `tagger.create_tagger`.
    data: Post request for media tagging.

  Returns:
    JSON response with the serialized tagging result.

  Raises:
    fastapi.HTTPException: With status 400 when no media URL is provided.
  """
  if not (media_url := data.get('media_url')):
    # A client error: report 400 instead of an unhandled 500.
    raise fastapi.HTTPException(
        status_code=400, detail='No path to media is provided.'
    )
  if not (concrete_tagger := taggers.get(tagger_type)):
    concrete_tagger = tagger.create_tagger(tagger_type)
    taggers[tagger_type] = concrete_tagger
  media_name = utils.convert_path_to_media_name(media_url)
  logging.info('Processing media: %s', media_url)
  # NOTE(review): smart_open and the tagger calls are synchronous, so inside
  # the `async def` endpoints below they run on the event loop thread.
  with smart_open.open(media_url, 'rb') as f:
    media_bytes = f.read()
  # Tolerate a missing/null `tagging_parameters` instead of failing on **None.
  tagging_options = base_tagger.TaggingOptions(
      **(data.get('tagging_parameters') or {})
  )
  tagging_result = concrete_tagger.tag(
      name=media_name, content=media_bytes, tagging_options=tagging_options
  )
  return fastapi.responses.JSONResponse(
      content=fastapi.encoders.jsonable_encoder(tagging_result.dict())
  )


@app.post('/tagger/llm')
async def tag_with_llm(
    data: MediaPostRequest = fastapi.Body(embed=True),
) -> fastapi.responses.JSONResponse:
  """Performs media tagging via LLMs.

  Args:
    data: Post request for media tagging.

  Returns:
    Json results of tagging.
  """
  return _tag_media_url('gemini-image', data)


@app.post('/tagger/api')
async def tag_with_api(
    data: MediaPostRequest = fastapi.Body(embed=True),
) -> fastapi.responses.JSONResponse:
  """Performs media tagging via Google Cloud APIs.

  Args:
    data: Post request for media tagging.

  Returns:
    Json results of tagging.
  """
  return _tag_media_url('vision-api', data)
File without changes
@@ -0,0 +1,63 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Module for defining various LLMs."""
15
+
16
+ from __future__ import annotations
17
+
18
+ import os
19
+
20
+ import langchain_google_genai as genai
21
+ from langchain_core import language_models
22
+
23
# Safety settings passed to Gemini: every listed harm category uses the
# BLOCK_NONE threshold.
_GEMINI_SAFETY_SETTINGS: dict[genai.HarmCategory, genai.HarmBlockThreshold] = {
    category: genai.HarmBlockThreshold.BLOCK_NONE
    for category in (
        genai.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        genai.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        genai.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
}
28
+
29
+
30
def create_llm(
    llm_type: str, llm_parameters: dict[str, str] | None = None
) -> language_models.BaseLanguageModel:
  """Creates LLM based on type and parameters.

  Args:
    llm_type: Type of LLM to instantiate.
    llm_parameters: Various parameters to instantiate LLM. The dictionary is
      not modified.

  Returns:
    Initialized LLM.

  Raises:
    InvalidLLMTypeError: When incorrect LLM type is specified.
  """
  mapping = {
      'gemini': genai.ChatGoogleGenerativeAI,
  }
  if not (llm := mapping.get(llm_type)):
    raise InvalidLLMTypeError(f'Unsupported LLM type: {llm_type}')
  # Work on a copy: updating the caller's dict in place leaked the safety
  # settings and API key into it.
  llm_parameters = dict(llm_parameters or {})
  if llm_type == 'gemini':
    llm_parameters.update(
        {
            'safety_settings': _GEMINI_SAFETY_SETTINGS,
            'google_api_key': os.environ.get('GOOGLE_API_KEY'),
        }
    )
  return llm(**llm_parameters)


class InvalidLLMTypeError(Exception):
  """Error when incorrect LLM type is specified."""
@@ -0,0 +1,60 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Module for performing media tagging.
15
+
16
+ Media tagging sends API requests to tagging engine (i.e. Google Vision API)
17
+ and returns tagging results that can be easily written.
18
+ """
19
+
20
+ from media_tagging.taggers import api, base, llm
21
+
22
# LLM-backed tagger names and the `llm.LLMTaggerTypeEnum` value that
# `create_tagger` passes as `tagger_type` when instantiating them.
_LLM_TAGGERS_TYPES = {
    'gemini-image': llm.LLMTaggerTypeEnum.UNSTRUCTURED,
    'gemini-structured-image': llm.LLMTaggerTypeEnum.STRUCTURED,
    'gemini-description-image': llm.LLMTaggerTypeEnum.DESCRIPTION,
}

# Tagger names accepted by `create_tagger` mapped to tagger classes; all
# LLM-backed names share the Gemini image tagger class.
_TAGGERS = {
    'vision-api': api.GoogleVisionAPITagger,
    'video-api': api.GoogleVideoIntelligenceAPITagger,
    **dict.fromkeys(_LLM_TAGGERS_TYPES, llm.GeminiImageTagger),
}
35
+
36
+
37
def create_tagger(
    tagger_type: str, tagger_parameters: dict[str, str] | None = None
) -> base.BaseTagger:
  """Factory for creating taggers based on provided type.

  Args:
    tagger_type: Type of tagger (one of the keys of `_TAGGERS`).
    tagger_parameters: Various parameters to instantiate tagger.

  Returns:
    Initialized concrete tagger.

  Raises:
    ValueError: When an unsupported tagger type is provided.
  """
  if not (tagger := _TAGGERS.get(tagger_type)):
    # Report the requested tagger type (previously the builtin `type` was
    # interpolated here by mistake).
    raise ValueError(
        f'Incorrect tagger {tagger_type} is provided, '
        f'valid options: {list(_TAGGERS.keys())}'
    )
  if not tagger_parameters:
    tagger_parameters = {}
  if issubclass(tagger, llm.LLMTagger):
    return tagger(
        tagger_type=_LLM_TAGGERS_TYPES.get(tagger_type), **tagger_parameters
    )
  return tagger(**tagger_parameters)
@@ -0,0 +1,119 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Module for performing media tagging via Google APIs.
15
+
16
+ Media tagging sends API requests to tagging engine (i.e. Google Vision API)
17
+ and returns tagging results that can be easily written.
18
+ """
19
+
20
+ from collections import defaultdict
21
+
22
+ from google.cloud import videointelligence, vision
23
+ from typing_extensions import override
24
+
25
+ from media_tagging.taggers import base
26
+
27
+
28
class GoogleVisionAPITagger(base.BaseTagger):
  """Tagger responsible for getting image tags from Cloud Vision API.

  Attributes:
    client: Vision API client responsible to tagging.
  """

  def __init__(self, project: str | None = None) -> None:
    """Initializes GoogleVisionAPITagger with a given project.

    Args:
      project: Google Cloud project id.
    """
    # NOTE(review): `project` is stored but not passed to the client.
    self._project = project
    self._client = None

  @property
  def client(self) -> vision.ImageAnnotatorClient:
    """Creates ImageAnnotatorClient on first access and reuses it."""
    if getattr(self, '_client', None) is None:
      self._client = vision.ImageAnnotatorClient()
    return self._client

  @override
  def tag(
      self,
      name: str,
      content: bytes,
      tagging_options: base.TaggingOptions = base.TaggingOptions(),
      **kwargs: str,
  ) -> base.TaggingResult:
    """Tags an image via Cloud Vision label detection.

    Args:
      name: Identifier of the image being tagged.
      content: Raw bytes of the image.
      tagging_options: Parameters to refine the tagging results.
      kwargs: Unused; accepted for compatibility with `BaseTagger.tag`.

    Returns:
      Image tagging results with one tag per label annotation.

    Raises:
      RuntimeError: When Vision API reports an error for the image.
    """
    image = vision.Image(content=content)
    response = self.client.label_detection(image=image)
    # Vision API reports per-image failures in `response.error` rather than
    # raising; without this check they surface as an empty list of tags.
    if response.error.message:
      raise RuntimeError(
          f'Vision API failed to tag {name}: {response.error.message}'
      )
    tags = [
        base.Tag(name=annotation.description, score=annotation.score)
        for annotation in response.label_annotations
    ]
    if n_tags := tagging_options.n_tags:
      tags = self._limit_number_of_tags(tags, n_tags)
    return base.TaggingResult(identifier=name, type='image', content=tags)
65
+
66
+
67
class GoogleVideoIntelligenceAPITagger(base.BaseTagger):
  """Tagger responsible for getting video tags from Video Intelligence API.

  Attributes:
    client: Video Intelligence API client responsible to tagging.
  """

  def __init__(self, project: str | None = None) -> None:
    """Initializes GoogleVideoIntelligenceAPITagger with a given project.

    Args:
      project: Google Cloud project id.
    """
    # NOTE(review): `project` is stored but not passed to the client.
    self._project = project
    self._client = None

  @property
  def client(self) -> videointelligence.VideoIntelligenceServiceClient:
    """Creates VideoIntelligenceServiceClient on first access and reuses it."""
    if getattr(self, '_client', None) is None:
      self._client = videointelligence.VideoIntelligenceServiceClient()
    return self._client

  @override
  def tag(
      self,
      name: str,
      content: bytes,
      tagging_options: base.TaggingOptions = base.TaggingOptions(),
      **kwargs: str,
  ) -> base.TaggingResult:
    """Tags a video via Video Intelligence label detection.

    Args:
      name: Identifier of the video being tagged.
      content: Raw bytes of the video.
      tagging_options: Parameters to refine the tagging results.
      kwargs: Unused; accepted for compatibility with `BaseTagger.tag`.

    Returns:
      Video tagging results with one tag per detected frame label.
    """
    request = videointelligence.AnnotateVideoRequest(
        input_content=content,
        video_context=videointelligence.VideoContext(
            label_detection_config=videointelligence.LabelDetectionConfig(
                frame_confidence_threshold=0.11,
                label_detection_mode=(
                    videointelligence.LabelDetectionMode.SHOT_AND_FRAME_MODE
                ),
            )
        ),
        features=[videointelligence.Feature.LABEL_DETECTION],
    )
    operation = self.client.annotate_video(request)
    # Waits for the long-running annotation operation for up to 180 seconds.
    response = operation.result(timeout=180)
    # A label's score is the sum of its per-frame confidences, so it may be
    # greater than 1.
    tags_scores: dict[str, float] = defaultdict(float)
    for frame_label in response.annotation_results[0].frame_label_annotations:
      tags_scores[frame_label.entity.description] += sum(
          frame.confidence for frame in frame_label.frames
      )

    tags = [
        base.Tag(name=label, score=score)
        for label, score in tags_scores.items()
    ]
    if n_tags := tagging_options.n_tags:
      tags = self._limit_number_of_tags(tags, n_tags)
    return base.TaggingResult(identifier=name, type='video', content=tags)
@@ -0,0 +1,115 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Module for defining common interface for taggers."""
15
+
16
+ import abc
17
+ import dataclasses
18
+ from collections.abc import MutableSequence, Sequence
19
+ from typing import Literal
20
+
21
+ import pydantic
22
+
23
+
24
+ @dataclasses.dataclass
25
+ class TaggingOptions:
26
+ """Specifies options to refine media tagging.
27
+
28
+ Attributes:
29
+ n_tags: Max number of tags to return.
30
+ tags: Particular tags to find in the media.
31
+ """
32
+
33
+ n_tags: int | None = None
34
+ tags: Sequence[str] | None = None
35
+
36
+ def __post_init__(self): # noqa: D105
37
+ if self.tags and not isinstance(self.tags, MutableSequence):
38
+ self.tags = [tag.strip() for tag in self.tags.split(',')]
39
+
40
+
41
class Tag(pydantic.BaseModel):
  """A single tag produced by a tagger.

  Attributes:
    name: Descriptive name of the tag.
    score: Score assigned to the tag.
  """

  # Both fields are required (no default value).
  name: str = pydantic.Field(..., description='tag_name')
  score: float = pydantic.Field(..., description='tag_score from 0 to 1')
51
+
52
+
53
class Description(pydantic.BaseModel):
  """Brief textual description of a media item.

  Attributes:
    text: Textual description of the media.
  """

  # Required field (no default value).
  text: str = pydantic.Field(
      ..., description='brief description of the media'
  )
61
+
62
+
63
class TaggingResult(pydantic.BaseModel):
  """Contains tagging information for a given identifier.

  Attributes:
    identifier: Unique identifier of a media being tagged.
    type: Type of media, either 'image' or 'video'.
    content: Either a list of tags or a textual description associated with
      a given media.
  """

  identifier: str = pydantic.Field(description='media identifier')
  type: Literal['image', 'video'] = pydantic.Field(description='type of media')
  content: list[Tag] | Description = pydantic.Field(
      description='tags or description in the result'
  )
77
+
78
+
79
class BaseTagger(abc.ABC):
  """Common interface that every concrete tagger implements."""

  @abc.abstractmethod
  def tag(
      self,
      name: str,
      content: bytes,
      tagging_options: TaggingOptions = TaggingOptions(),
      **kwargs: str,
  ) -> TaggingResult:
    """Sends media bytes to a tagging engine.

    Args:
      name: Identifier of the media content being tagged.
      content: Raw bytes of the media.
      tagging_options: Parameters to refine the tagging results.
      kwargs: Optional keywords arguments to be sent for tagging.

    Returns:
      Results of tagging.
    """

  def _limit_number_of_tags(
      self, tags: Sequence[Tag], n_tags: int
  ) -> list[Tag]:
    """Keeps only the `n_tags` highest-scoring tags.

    Args:
      tags: All tags produced by tagging algorithm.
      n_tags: Max number of tags to return.

    Returns:
      At most `n_tags` tags ordered by descending score; tags with equal
      scores keep their original relative order (sorting is stable).
    """
    return sorted(tags, key=lambda tag: tag.score, reverse=True)[:n_tags]