media-tagging 0.2.0.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- media-tagging-0.2.0.dev1/PKG-INFO +13 -0
- media-tagging-0.2.0.dev1/README.md +48 -0
- media-tagging-0.2.0.dev1/entrypoints/__init__.py +0 -0
- media-tagging-0.2.0.dev1/entrypoints/cli.py +93 -0
- media-tagging-0.2.0.dev1/entrypoints/server.py +101 -0
- media-tagging-0.2.0.dev1/media_tagging/__init__.py +0 -0
- media-tagging-0.2.0.dev1/media_tagging/llms.py +63 -0
- media-tagging-0.2.0.dev1/media_tagging/tagger.py +60 -0
- media-tagging-0.2.0.dev1/media_tagging/taggers/__init__.py +0 -0
- media-tagging-0.2.0.dev1/media_tagging/taggers/api.py +119 -0
- media-tagging-0.2.0.dev1/media_tagging/taggers/base.py +115 -0
- media-tagging-0.2.0.dev1/media_tagging/taggers/llm.py +192 -0
- media-tagging-0.2.0.dev1/media_tagging/tools.py +70 -0
- media-tagging-0.2.0.dev1/media_tagging/utils.py +20 -0
- media-tagging-0.2.0.dev1/media_tagging/writer.py +153 -0
- media-tagging-0.2.0.dev1/media_tagging.egg-info/PKG-INFO +13 -0
- media-tagging-0.2.0.dev1/media_tagging.egg-info/SOURCES.txt +28 -0
- media-tagging-0.2.0.dev1/media_tagging.egg-info/dependency_links.txt +1 -0
- media-tagging-0.2.0.dev1/media_tagging.egg-info/entry_points.txt +2 -0
- media-tagging-0.2.0.dev1/media_tagging.egg-info/requires.txt +11 -0
- media-tagging-0.2.0.dev1/media_tagging.egg-info/top_level.txt +3 -0
- media-tagging-0.2.0.dev1/setup.cfg +4 -0
- media-tagging-0.2.0.dev1/setup.py +54 -0
- media-tagging-0.2.0.dev1/tests/__init__.py +0 -0
- media-tagging-0.2.0.dev1/tests/conftest.py +34 -0
- media-tagging-0.2.0.dev1/tests/end_to_end/__init__.py +0 -0
- media-tagging-0.2.0.dev1/tests/end_to_end/test_main.py +48 -0
- media-tagging-0.2.0.dev1/tests/unit/__init__.py +0 -0
- media-tagging-0.2.0.dev1/tests/unit/test_tagger.py +84 -0
- media-tagging-0.2.0.dev1/tests/unit/test_writer.py +119 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: media-tagging
|
|
3
|
+
Version: 0.2.0.dev1
|
|
4
|
+
Author: Google Inc. (gTech gPS CSE team)
|
|
5
|
+
Author-email: no-reply@google.com
|
|
6
|
+
License: Apache 2.0
|
|
7
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Welltech Media Tagging
|
|
2
|
+
|
|
3
|
+
## Prerequisites
|
|
4
|
+
|
|
5
|
+
* Google Cloud project with billing enabled.
|
|
6
|
+
* [Video Intelligence API](https://console.cloud.google.com/apis/library/videointelligence.googleapis.com) and [Vision API](https://console.cloud.google.com/apis/library/vision.googleapis.com) enabled.
|
|
7
|
+
* Python 3.10+
|
|
8
|
+
* Access to the repository configured. In order to clone this repository you need
|
|
9
|
+
to do the following:
|
|
10
|
+
* Visit https://professional-services.googlesource.com/new-password and
|
|
11
|
+
login with your account.
|
|
12
|
+
* Once authenticated, copy all lines in the box
|
|
13
|
+
and paste them in the terminal.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
## Run
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
1. Install `media-tagging`
|
|
20
|
+
|
|
21
|
+
```
|
|
22
|
+
pip install media-tagging
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
2. Perform tagging
|
|
26
|
+
|
|
27
|
+
```
|
|
28
|
+
media-tagger --media-path MEDIA_PATH --tagger TAGGER_TYPE --writer WRITER_TYPE
|
|
29
|
+
```
|
|
30
|
+
where:
|
|
31
|
+
* MEDIA_PATH - comma-separated names of files for tagging (can be URLs).
|
|
32
|
+
* TAGGER_TYPE - name of tagger, supported options:
|
|
33
|
+
* `vision-api` - tags images based on [Google Cloud Vision API](https://cloud.google.com/vision/).
|
|
34
|
+
* `video-api` - tags videos based on [Google Cloud Video Intelligence API](https://cloud.google.com/video-intelligence/).
|
|
35
|
+
* `gemini-image` - Uses Gemini to tag images. Add `--tagger.n_tags=<N_TAGS>`
|
|
36
|
+
parameter to control the number of tags returned by the tagger.
|
|
37
|
+
* `gemini-structured-image` - Uses Gemini to find certain tags in the images.
|
|
38
|
+
Add `--tagger.tags='tag1, tag2, ..., tagN'` parameter to find certain tags
|
|
39
|
+
in the image.
|
|
40
|
+
* `gemini-description-image` - Provides a brief description of the image.
|
|
41
|
+
* WRITER_TYPE - name of writer, one of `csv`, `json`
|
|
42
|
+
|
|
43
|
+
By default, the script creates a separate file with tagging results for each media path.
|
|
44
|
+
If you want to combine results into a single file, add the `--output OUTPUT_NAME` flag (short for `--output-to-file`; without extension, e.g. `--output tagging_sample`).
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
## Disclaimer
|
|
48
|
+
This is not an officially supported Google product.
|
|
File without changes
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Provides CLI for media tagging."""
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
import smart_open
|
|
21
|
+
from gaarf.cli import utils as gaarf_utils
|
|
22
|
+
|
|
23
|
+
from media_tagging import tagger, utils, writer
|
|
24
|
+
from media_tagging.taggers import base as base_tagger
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def tag_media(
    media_path: str | os.PathLike,
    tagger_type: base_tagger.BaseTagger,
    writer_type: writer.BaseWriter | None = None,
    single_output_name: str | None = None,
    tagging_parameters: dict[str, str] | None = None,
) -> None:
  """Runs media tagging algorithm and writes the results.

  Args:
    media_path: Local or remote path to media file; several paths can be
      provided as a single comma-separated string.
    tagger_type: Initialized tagger.
    writer_type: Initialized writer for saving tagging results; a new
      `writer.JsonWriter()` is used when omitted.
    single_output_name: Parameter for saving results to a single file.
    tagging_parameters: Optional keywords arguments to be sent for tagging.

  Raises:
    ValueError: When no media path is provided.
  """
  if writer_type is None:
    # Created per call rather than as a default argument value, which would
    # be evaluated once at import time and shared by every call.
    writer_type = writer.JsonWriter()
  # Tolerate whitespace around commas ("a.png, b.png") and empty entries
  # such as a trailing comma.
  media_paths = [
      path.strip() for path in os.fspath(media_path).split(',') if path.strip()
  ]
  if not media_paths:
    raise ValueError('No path to media is provided.')
  # Built once up front so invalid parameters fail before any media is read.
  tagging_options = base_tagger.TaggingOptions(**(tagging_parameters or {}))
  results = []
  for path in media_paths:
    media_name = utils.convert_path_to_media_name(path)
    logging.info('Processing media: %s', path)
    with smart_open.open(path, 'rb') as f:
      media_bytes = f.read()
    results.append(
        tagger_type.tag(
            media_name,
            media_bytes,
            tagging_options=tagging_options,
        )
    )
  writer_type.write(results, single_output_name)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def main():
  """Main entrypoint: parses CLI flags, tags media and writes results.

  Flags not declared below (e.g. `--tagger.n_tags=10`) are collected by
  `parse_known_args` and parsed by gaarf's `ParamsParser` under the
  `tagger` namespace; they are forwarded to `tag_media` as tagging
  parameters.
  """
  parser = argparse.ArgumentParser()
  # Required: without it tag_media has nothing to split and would fail
  # later with an unhelpful error.
  parser.add_argument('--media-path', dest='media_path', required=True)
  parser.add_argument('--tagger', dest='tagger', default='vision-api')
  parser.add_argument('--writer', dest='writer', default='json')
  parser.add_argument('--output-to-file', dest='output', default=None)
  parser.add_argument('--loglevel', dest='loglevel', default='INFO')
  args, kwargs = parser.parse_known_args()

  # Configure logging first so messages emitted while creating the tagger
  # and writer are formatted and filtered as requested.
  logging.basicConfig(
      format='[%(asctime)s][%(name)s][%(levelname)s] %(message)s',
      level=args.loglevel.upper(),
      datefmt='%Y-%m-%d %H:%M:%S',
  )

  concrete_tagger = tagger.create_tagger(args.tagger)
  concrete_writer = writer.create_writer(args.writer)
  tagging_parameters = gaarf_utils.ParamsParser(['tagger']).parse(kwargs)

  tag_media(
      media_path=args.media_path,
      tagger_type=concrete_tagger,
      writer_type=concrete_writer,
      single_output_name=args.output,
      tagging_parameters=tagging_parameters.get('tagger'),
  )


if __name__ == '__main__':
  main()
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Provides HTTP endpoint for media tagging."""
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
import fastapi
|
|
19
|
+
import smart_open
|
|
20
|
+
from typing_extensions import TypedDict
|
|
21
|
+
|
|
22
|
+
from media_tagging import tagger, utils
|
|
23
|
+
from media_tagging.tagger import base as base_tagger
|
|
24
|
+
|
|
25
|
+
# Cache of already-initialized taggers keyed by tagger type; the endpoints in
# this module look a tagger up here and only create it on first use.
taggers: dict[str, base_tagger.BaseTagger] = {}

# FastAPI application exposing the tagging endpoints.
app = fastapi.FastAPI()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MediaPostRequest(TypedDict):
|
|
30
|
+
"""Specifies structure of request for tagging media.
|
|
31
|
+
|
|
32
|
+
Attributes:
|
|
33
|
+
media_url: Local or remote URL of media.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
media_url: str
|
|
37
|
+
tagging_parameters: dict[str, int | list[str]]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_tagger(tagger_name: str) -> base_tagger.BaseTagger:
  """Returns the cached tagger of the given type, creating it on first use."""
  if not (concrete_tagger := taggers.get(tagger_name)):
    concrete_tagger = tagger.create_tagger(tagger_name)
    taggers[tagger_name] = concrete_tagger
  return concrete_tagger


def _tag_media(
    tagger_name: str, data: MediaPostRequest
) -> fastapi.responses.JSONResponse:
  """Reads media referenced in the request and tags it with a given tagger.

  Args:
    tagger_name: Tagger type understood by `tagger.create_tagger`.
    data: Post request for media tagging.

  Returns:
    Json results of tagging.

  Raises:
    fastapi.HTTPException: With status 400 when no media URL is provided or
      tagging parameters are invalid.
  """
  concrete_tagger = _get_tagger(tagger_name)
  media_url = data.get('media_url')
  if not media_url:
    # A client error, not a server failure: report 400 instead of letting a
    # plain exception surface as a 500.
    raise fastapi.HTTPException(
        status_code=400, detail='No path to media is provided.'
    )
  try:
    tagging_options = base_tagger.TaggingOptions(
        **(data.get('tagging_parameters') or {})
    )
  except (TypeError, ValueError) as e:
    raise fastapi.HTTPException(status_code=400, detail=str(e)) from e
  media_name = utils.convert_path_to_media_name(media_url)
  logging.info('Processing media: %s', media_url)
  # NOTE(review): reading the media and tagging are blocking calls executed
  # inside async endpoints, so they block the event loop; consider plain
  # `def` endpoints or a threadpool.
  with smart_open.open(media_url, 'rb') as f:
    media_bytes = f.read()
  tagging_result = concrete_tagger.tag(
      name=media_name, content=media_bytes, tagging_options=tagging_options
  )
  return fastapi.responses.JSONResponse(
      content=fastapi.encoders.jsonable_encoder(tagging_result.dict())
  )


@app.post('/tagger/llm')
async def tag_with_llm(
    data: MediaPostRequest = fastapi.Body(embed=True),
) -> fastapi.responses.JSONResponse:
  """Performs media tagging via LLMs.

  Args:
    data: Post request for media tagging.

  Returns:
    Json results of tagging.
  """
  return _tag_media('gemini-image', data)


@app.post('/tagger/api')
async def tag_with_api(
    data: MediaPostRequest = fastapi.Body(embed=True),
) -> fastapi.responses.JSONResponse:
  """Performs media tagging via Google Cloud APIs.

  Args:
    data: Post request for media tagging.

  Returns:
    Json results of tagging.
  """
  return _tag_media('vision-api', data)
|
|
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Module for defining various LLMs."""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
import langchain_google_genai as genai
|
|
21
|
+
from langchain_core import language_models
|
|
22
|
+
|
|
23
|
+
# Safety settings passed to Gemini by `create_llm`: every listed harm
# category uses the BLOCK_NONE threshold.
# NOTE(review): HARM_CATEGORY_HATE_SPEECH is not listed, so it presumably
# keeps the API default threshold — confirm this is intended.
_GEMINI_SAFETY_SETTINGS: dict[genai.HarmCategory, genai.HarmBlockThreshold] = {
    category: genai.HarmBlockThreshold.BLOCK_NONE
    for category in (
        genai.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        genai.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        genai.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def create_llm(
    llm_type: str, llm_parameters: dict[str, str] | None = None
) -> language_models.BaseLanguageModel:
  """Creates LLM based on type and parameters.

  Args:
    llm_type: Type of LLM to instantiate (currently only 'gemini').
    llm_parameters: Various parameters to instantiate LLM. The mapping is not
      modified. For Gemini, `safety_settings` and `google_api_key` default to
      `_GEMINI_SAFETY_SETTINGS` and the `GOOGLE_API_KEY` environment variable;
      explicitly passed values take precedence.

  Returns:
    Initialized LLM.

  Raises:
    InvalidLLMTypeError: When incorrect LLM type is specified.
  """
  mapping = {
      'gemini': genai.ChatGoogleGenerativeAI,
  }
  if not (llm := mapping.get(llm_type)):
    raise InvalidLLMTypeError(f'Unsupported LLM type: {llm_type}')
  # Copy so the caller's dict is never mutated.
  parameters = dict(llm_parameters or {})
  if llm_type == 'gemini':
    # Defaults first so caller-supplied values win; previously an unset
    # GOOGLE_API_KEY would overwrite an explicitly passed key with None.
    parameters = {
        'safety_settings': _GEMINI_SAFETY_SETTINGS,
        'google_api_key': os.environ.get('GOOGLE_API_KEY'),
        **parameters,
    }
  return llm(**parameters)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class InvalidLLMTypeError(Exception):
  """Raised by `create_llm` when the requested LLM type is not supported."""
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Module for performing media tagging.
|
|
15
|
+
|
|
16
|
+
Media tagging sends API requests to tagging engine (i.e. Google Vision API)
|
|
17
|
+
and returns tagging results that can be easily written.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from media_tagging.taggers import api, base, llm
|
|
21
|
+
|
|
22
|
+
# Output kind each Gemini-based tagger name asks its LLM tagger to produce;
# `create_tagger` passes it as the `tagger_type` argument.
_LLM_TAGGERS_TYPES = {
    'gemini-image': llm.LLMTaggerTypeEnum.UNSTRUCTURED,
    'gemini-structured-image': llm.LLMTaggerTypeEnum.STRUCTURED,
    'gemini-description-image': llm.LLMTaggerTypeEnum.DESCRIPTION,
}

# Registry of every supported tagger name and the class implementing it.
# All Gemini variants share GeminiImageTagger and differ only by the type
# looked up in _LLM_TAGGERS_TYPES.
_TAGGERS = {
    'vision-api': api.GoogleVisionAPITagger,
    'video-api': api.GoogleVideoIntelligenceAPITagger,
    **dict.fromkeys(_LLM_TAGGERS_TYPES, llm.GeminiImageTagger),
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create_tagger(
    tagger_type: str, tagger_parameters: dict[str, str] | None = None
) -> base.BaseTagger:
  """Factory for creating taggers based on provided type.

  Args:
    tagger_type: Type of tagger (one of the keys of `_TAGGERS`).
    tagger_parameters: Various parameters to instantiate tagger.

  Returns:
    Concrete tagger instance.

  Raises:
    ValueError: When an unsupported tagger type is provided.
  """
  tagger_parameters = tagger_parameters or {}
  if not (concrete_tagger := _TAGGERS.get(tagger_type)):
    # Report the requested type; previously the builtin `type` was
    # interpolated, printing "<class 'type'>".
    raise ValueError(
        f'Incorrect tagger {tagger_type} is provided, '
        f'valid options: {list(_TAGGERS.keys())}'
    )
  if issubclass(concrete_tagger, llm.LLMTagger):
    # LLM taggers share a class and are specialized by their output type.
    return concrete_tagger(
        tagger_type=_LLM_TAGGERS_TYPES.get(tagger_type), **tagger_parameters
    )
  return concrete_tagger(**tagger_parameters)
|
|
File without changes
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Module for performing media tagging via Google APIs.
|
|
15
|
+
|
|
16
|
+
Media tagging sends API requests to tagging engine (i.e. Google Vision API)
|
|
17
|
+
and returns tagging results that can be easily written.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from collections import defaultdict
|
|
21
|
+
|
|
22
|
+
from google.cloud import videointelligence, vision
|
|
23
|
+
from typing_extensions import override
|
|
24
|
+
|
|
25
|
+
from media_tagging.taggers import base
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GoogleVisionAPITagger(base.BaseTagger):
  """Tagger responsible for getting image tags from Cloud Vision API.

  Attributes:
    client: Vision API client responsible for tagging.
  """

  def __init__(self, project: str | None = None) -> None:
    """Initializes GoogleVisionAPITagger with a given project.

    Args:
      project: Google Cloud project id. NOTE(review): stored but not
        currently used when creating the client.
    """
    self._project = project
    self._client = None

  @property
  def client(self) -> vision.ImageAnnotatorClient:
    """Lazily creates ImageAnnotatorClient and reuses it on later calls."""
    if self._client is None:
      self._client = vision.ImageAnnotatorClient()
    return self._client

  @override
  def tag(
      self,
      name: str,
      content: bytes,
      tagging_options: base.TaggingOptions = base.TaggingOptions(),
      **kwargs: str,
  ) -> base.TaggingResult:
    """Tags image bytes via Vision API label detection.

    Args:
      name: Identifier of the image being tagged.
      content: Raw bytes of the image.
      tagging_options: Parameters to refine the tagging results; only
        `n_tags` is used.
      kwargs: Unused; accepted for interface compatibility.

    Returns:
      Tagging result with one tag per label annotation.

    Raises:
      RuntimeError: When Vision API reports an error for the image.
    """
    image = vision.Image(content=content)
    response = self.client.label_detection(image=image)
    # Vision API reports per-image failures inside the response instead of
    # raising; surface them rather than returning an empty tag list.
    if response.error.message:
      raise RuntimeError(
          f'Vision API failed to tag {name}: {response.error.message}'
      )
    tags = [
        base.Tag(name=label.description, score=label.score)
        for label in response.label_annotations
    ]
    if n_tags := tagging_options.n_tags:
      tags = self._limit_number_of_tags(tags, n_tags)
    return base.TaggingResult(identifier=name, type='image', content=tags)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class GoogleVideoIntelligenceAPITagger(base.BaseTagger):
  """Tagger responsible for getting video tags from Video Intelligence API.

  Attributes:
    client: Video Intelligence API client responsible for tagging.
  """

  def __init__(
      self, project: str | None = None, timeout: float = 180
  ) -> None:
    """Initializes GoogleVideoIntelligenceAPITagger with a given project.

    Args:
      project: Google Cloud project id. NOTE(review): stored but not
        currently used when creating the client.
      timeout: Seconds to wait for the annotation operation to finish.
    """
    self._project = project
    self._timeout = timeout
    self._client = None

  @property
  def client(self) -> videointelligence.VideoIntelligenceServiceClient:
    """Lazily creates VideoIntelligenceServiceClient and reuses it."""
    if self._client is None:
      self._client = videointelligence.VideoIntelligenceServiceClient()
    return self._client

  @override
  def tag(
      self,
      name: str,
      content: bytes,
      tagging_options: base.TaggingOptions = base.TaggingOptions(),
      **kwargs: str,
  ) -> base.TaggingResult:
    """Tags video bytes via Video Intelligence API label detection.

    Args:
      name: Identifier of the video being tagged.
      content: Raw bytes of the video.
      tagging_options: Parameters to refine the tagging results; only
        `n_tags` is used.
      kwargs: Unused; accepted for consistency with BaseTagger.tag.

    Returns:
      Tagging result with one tag per distinct frame label.

    Raises:
      RuntimeError: When the API returns no annotation results.
    """
    request = videointelligence.AnnotateVideoRequest(
        input_content=content,
        video_context=videointelligence.VideoContext(
            label_detection_config=videointelligence.LabelDetectionConfig(
                frame_confidence_threshold=0.11,
                label_detection_mode=(
                    videointelligence.LabelDetectionMode.SHOT_AND_FRAME_MODE
                ),
            )
        ),
        features=[videointelligence.Feature.LABEL_DETECTION],
    )
    operation = self.client.annotate_video(request)
    response = operation.result(timeout=self._timeout)
    if not response.annotation_results:
      raise RuntimeError(
          f'Video Intelligence API returned no annotation results for {name}'
      )
    # Score of a label is the sum of its per-frame confidences, so it ranks
    # labels by how often and how confidently they appear and can exceed 1.
    # Only frame-level labels are used; shot-level labels are ignored.
    tags_scores: dict[str, float] = defaultdict(float)
    for frame_label in response.annotation_results[0].frame_label_annotations:
      tags_scores[frame_label.entity.description] += sum(
          frame.confidence for frame in frame_label.frames
      )

    tags = [
        base.Tag(name=tag_name, score=score)
        for tag_name, score in tags_scores.items()
    ]
    if n_tags := tagging_options.n_tags:
      tags = self._limit_number_of_tags(tags, n_tags)
    return base.TaggingResult(identifier=name, type='video', content=tags)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Module for defining common interface for taggers."""
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
import dataclasses
|
|
18
|
+
from collections.abc import MutableSequence, Sequence
|
|
19
|
+
from typing import Literal
|
|
20
|
+
|
|
21
|
+
import pydantic
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclasses.dataclass
|
|
25
|
+
class TaggingOptions:
|
|
26
|
+
"""Specifies options to refine media tagging.
|
|
27
|
+
|
|
28
|
+
Attributes:
|
|
29
|
+
n_tags: Max number of tags to return.
|
|
30
|
+
tags: Particular tags to find in the media.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
n_tags: int | None = None
|
|
34
|
+
tags: Sequence[str] | None = None
|
|
35
|
+
|
|
36
|
+
def __post_init__(self): # noqa: D105
|
|
37
|
+
if self.tags and not isinstance(self.tags, MutableSequence):
|
|
38
|
+
self.tags = [tag.strip() for tag in self.tags.split(',')]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Tag(pydantic.BaseModel):
  """A single label produced by a tagger.

  Attributes:
    name: Human-readable label of the tag.
    score: Score the tagging engine assigned to the label.
  """

  name: str = pydantic.Field(..., description='tag_name')
  score: float = pydantic.Field(..., description='tag_score from 0 to 1')
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Description(pydantic.BaseModel):
  """Free-text summary of a media item.

  Attributes:
    text: Short textual description of the media.
  """

  text: str = pydantic.Field(..., description='brief description of the media')
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class TaggingResult(pydantic.BaseModel):
  """Contains tagging information for a given identifier.

  Attributes:
    identifier: Unique identifier of a media being tagged.
    type: Type of media, either 'image' or 'video'.
    content: Tags associated with a given media, or a Description of it.
  """

  identifier: str = pydantic.Field(description='media identifier')
  # Field name shadows the builtin `type` but is part of the serialized
  # schema, so it is kept.
  type: Literal['image', 'video'] = pydantic.Field(description='type of media')
  content: list[Tag] | Description = pydantic.Field(
      description='tags or description in the result'
  )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class BaseTagger(abc.ABC):
  """Common interface implemented by every concrete tagger."""

  @abc.abstractmethod
  def tag(
      self,
      name: str,
      content: bytes,
      tagging_options: TaggingOptions = TaggingOptions(),
      **kwargs: str,
  ) -> TaggingResult:
    """Sends media bytes to a tagging engine.

    Args:
      name: Identifier of the media being tagged.
      content: Raw bytes of the media.
      tagging_options: Options that refine which tags are returned.
      kwargs: Extra tagger-specific keyword arguments.

    Returns:
      Result of tagging the media.
    """

  def _limit_number_of_tags(
      self, tags: Sequence[Tag], n_tags: int
  ) -> list[Tag]:
    """Keeps only the highest-scoring tags.

    Args:
      tags: Every tag produced by the tagging engine.
      n_tags: Upper bound on the number of tags to keep.

    Returns:
      At most `n_tags` tags, ordered from highest to lowest score.
    """
    return sorted(tags, key=lambda tag: tag.score, reverse=True)[:n_tags]
|