datashare-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datashare_python/__init__.py +0 -0
- datashare_python/__main__.py +4 -0
- datashare_python/app.py +85 -0
- datashare_python/cli/__init__.py +30 -0
- datashare_python/cli/tasks.py +182 -0
- datashare_python/cli/utils.py +33 -0
- datashare_python/config.py +60 -0
- datashare_python/constants.py +6 -0
- datashare_python/objects.py +49 -0
- datashare_python/task_client.py +124 -0
- datashare_python/tasks/__init__.py +2 -0
- datashare_python/tasks/classify_docs.py +227 -0
- datashare_python/tasks/dependencies.py +110 -0
- datashare_python/tasks/translate_docs.py +223 -0
- datashare_python/utils.py +69 -0
- datashare_python-0.1.0.dist-info/METADATA +80 -0
- datashare_python-0.1.0.dist-info/RECORD +19 -0
- datashare_python-0.1.0.dist-info/WHEEL +4 -0
- datashare_python-0.1.0.dist-info/entry_points.txt +4 -0
|
File without changes
|
datashare_python/app.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from icij_worker import AsyncApp
|
|
4
|
+
from icij_worker.typing_ import PercentProgress
|
|
5
|
+
from pydantic import parse_obj_as
|
|
6
|
+
|
|
7
|
+
from datashare_python.constants import PYTHON_TASK_GROUP
|
|
8
|
+
from datashare_python.objects import ClassificationConfig, TranslationConfig
|
|
9
|
+
from datashare_python.tasks import (
|
|
10
|
+
classify_docs as classify_docs_,
|
|
11
|
+
create_classification_tasks as create_classification_tasks_,
|
|
12
|
+
create_translation_tasks as create_translation_tasks_,
|
|
13
|
+
translate_docs as translate_docs_,
|
|
14
|
+
)
|
|
15
|
+
from datashare_python.tasks.dependencies import APP_LIFESPAN_DEPS
|
|
16
|
+
|
|
17
|
+
app = AsyncApp("ml", dependencies=APP_LIFESPAN_DEPS)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@app.task(group=PYTHON_TASK_GROUP)
async def create_translation_tasks(
    project: str,
    target_language: str,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> list[str]:
    """Task entry point which delegates to the translation task creation logic.

    The raw ``config`` mapping (or ``None``) is validated into an optional
    ``TranslationConfig`` and forwarded together with ``project`` and
    ``target_language``. ``user`` is accepted but not used here.
    """
    # Validate the incoming config, None is let through untouched
    parsed_config = parse_obj_as(Optional[TranslationConfig], config)
    created = await create_translation_tasks_(
        project=project, target_language=target_language, config=parsed_config
    )
    return created
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@app.task(group=PYTHON_TASK_GROUP)
async def translate_docs(
    docs: list[str],
    project: str,
    target_language: str,
    progress: PercentProgress,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> int:
    """Task entry point which delegates to the document translation logic.

    The raw ``config`` mapping (or ``None``) is validated into an optional
    ``TranslationConfig``, then ``docs``, ``target_language``, ``project`` and the
    ``progress`` callback are forwarded. ``user`` is accepted but not used here.
    """
    parsed_config = parse_obj_as(Optional[TranslationConfig], config)
    n_translated = await translate_docs_(
        docs,
        target_language,
        project=project,
        config=parsed_config,
        progress=progress,
    )
    return n_translated
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@app.task(group=PYTHON_TASK_GROUP)
async def create_classification_tasks(
    project: str,
    language: str,
    n_workers: int,
    progress: PercentProgress,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> list[str]:
    """Task entry point which delegates to the classification task creation logic.

    The raw ``config`` mapping (or ``None``) is validated into an optional
    ``ClassificationConfig`` and forwarded with the other arguments. ``user`` is
    accepted but not used here.
    """
    parsed_config = parse_obj_as(Optional[ClassificationConfig], config)
    forwarded = {
        "project": project,
        "language": language,
        "n_workers": n_workers,
        "config": parsed_config,
        "progress": progress,
    }
    return await create_classification_tasks_(**forwarded)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@app.task(group=PYTHON_TASK_GROUP)
async def classify_docs(
    docs: list[str],
    language: str,
    project: str,
    progress: PercentProgress,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> int:
    """Task entry point which delegates to the document classification logic.

    The raw ``config`` mapping is validated into a ``ClassificationConfig``; when
    no config is provided a default ``ClassificationConfig()`` is used. ``user`` is
    accepted but not used here.
    """
    parsed_config = parse_obj_as(Optional[ClassificationConfig], config)
    if parsed_config is None:
        # The classification function reads attributes of the config, explicitly
        # forwarding None would bypass any default it declares and make it fail
        parsed_config = ClassificationConfig()
    return await classify_docs_(
        docs,
        language=language,
        project=project,
        config=parsed_config,
        progress=progress,
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@app.task(group=PYTHON_TASK_GROUP)
def ping() -> str:
    """Trivial task which always answers with the string ``"pong"``."""
    reply = "pong"
    return reply
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import importlib.metadata
|
|
2
|
+
from typing import Annotated, Optional
|
|
3
|
+
|
|
4
|
+
import typer
|
|
5
|
+
|
|
6
|
+
import datashare_python
|
|
7
|
+
from datashare_python.cli.tasks import task_app
|
|
8
|
+
from datashare_python.cli.utils import AsyncTyper
|
|
9
|
+
|
|
10
|
+
cli_app = AsyncTyper(context_settings={"help_option_names": ["-h", "--help"]})
|
|
11
|
+
cli_app.add_typer(task_app)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def version_callback(value: bool):
    """Print the installed package version and exit when ``value`` is truthy.

    Does nothing (and returns ``None``) when ``value`` is falsy.

    Raises:
        typer.Exit: after the version has been printed.
    """
    if not value:
        return
    print(importlib.metadata.version(datashare_python.__name__))
    raise typer.Exit()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Root callback of the CLI application, registered under the "datashare-python"
# name. Its only option is the eager "--version" flag, handled by
# version_callback.
@cli_app.callback(name="datashare-python")
def main(
    # The value is never read in the body: the option only exists to trigger its
    # callback
    version: Annotated[  # pylint: disable=unused-argument
        Optional[bool],
        typer.Option(  # pylint: disable=unused-argument
            "--version", callback=version_callback, is_eager=True
        ),
    ] = None
):
    """Datashare Python CLI"""
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from traceback import FrameSummary, StackSummary
|
|
7
|
+
from typing import Annotated, Any, Optional
|
|
8
|
+
|
|
9
|
+
import typer
|
|
10
|
+
from alive_progress import alive_bar
|
|
11
|
+
from icij_worker import TaskState
|
|
12
|
+
from icij_worker.objects import READY_STATES, Task, TaskError
|
|
13
|
+
|
|
14
|
+
from datashare_python.cli.utils import AsyncTyper, eprint
|
|
15
|
+
from datashare_python.constants import PYTHON_TASK_GROUP
|
|
16
|
+
from datashare_python.task_client import DatashareTaskClient
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
DEFAULT_DS_ADDRESS = "http://localhost:8080"
|
|
21
|
+
|
|
22
|
+
_ARGS_HELP = "task argument as a JSON string or file path"
|
|
23
|
+
_GROUP_HELP = "task group"
|
|
24
|
+
_DS_API_KEY_HELP = "datashare API key"
|
|
25
|
+
_DS_URL_HELP = "datashare address"
|
|
26
|
+
_POLLING_INTERVAL_S_HELP = "task state polling interval in seconds"
|
|
27
|
+
_NAME_HELP = "registered task name"
|
|
28
|
+
_RESULT_HELP = "get a task result"
|
|
29
|
+
_START_HELP = "creates a new task and start it"
|
|
30
|
+
_TASK_ID_HELP = "task ID"
|
|
31
|
+
_WATCH_HELP = "watch a task until it's complete"
|
|
32
|
+
|
|
33
|
+
TaskArgs = str
|
|
34
|
+
|
|
35
|
+
task_app = AsyncTyper(name="task")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@task_app.async_command(help=_START_HELP)
|
|
39
|
+
async def start(
|
|
40
|
+
name: Annotated[str, typer.Argument(help=_NAME_HELP)],
|
|
41
|
+
args: Annotated[TaskArgs, typer.Argument(help=_ARGS_HELP)] = None,
|
|
42
|
+
group: Annotated[
|
|
43
|
+
Optional[str],
|
|
44
|
+
typer.Option("--group", "-g", help=_GROUP_HELP),
|
|
45
|
+
] = PYTHON_TASK_GROUP.name,
|
|
46
|
+
ds_address: Annotated[
|
|
47
|
+
str, typer.Option("--ds-address", "-a", help=_DS_URL_HELP)
|
|
48
|
+
] = DEFAULT_DS_ADDRESS,
|
|
49
|
+
ds_api_key: Annotated[
|
|
50
|
+
Optional[str], typer.Option("--ds-api-key", "-k", help=_DS_API_KEY_HELP)
|
|
51
|
+
] = None,
|
|
52
|
+
):
|
|
53
|
+
match args:
|
|
54
|
+
case str():
|
|
55
|
+
as_path = Path(name)
|
|
56
|
+
if as_path.exists():
|
|
57
|
+
args = json.loads(as_path.read_text())
|
|
58
|
+
else:
|
|
59
|
+
args = json.loads(args)
|
|
60
|
+
case None:
|
|
61
|
+
args = dict()
|
|
62
|
+
case _:
|
|
63
|
+
raise TypeError(f"Invalid args {args}")
|
|
64
|
+
client = DatashareTaskClient(ds_address, api_key=ds_api_key)
|
|
65
|
+
async with client:
|
|
66
|
+
task_id = await client.create_task(name, args, group=group)
|
|
67
|
+
eprint(f"Task({task_id}) started !")
|
|
68
|
+
eprint(f"Task({task_id}) 🛫")
|
|
69
|
+
print(task_id)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@task_app.async_command(help=_WATCH_HELP)
async def watch(
    task_id: Annotated[str, typer.Argument(help=_TASK_ID_HELP)],
    ds_address: Annotated[
        str, typer.Option("--ds-address", "-a", help=_DS_URL_HELP)
    ] = DEFAULT_DS_ADDRESS,
    ds_api_key: Annotated[
        Optional[str], typer.Option("--ds-api-key", "-k", help=_DS_API_KEY_HELP)
    ] = None,
    polling_interval_s: Annotated[
        float, typer.Option("--polling-interval-s", "-p", help=_POLLING_INTERVAL_S_HELP)
    ] = 1.0,
):
    """Follow a task until it reaches one of the ``READY_STATES``.

    A task which is already in a ready state is reported right away, otherwise
    its state is polled every ``polling_interval_s`` seconds. The task ID is
    printed to stdout at the end.
    """
    client = DatashareTaskClient(ds_address, api_key=ds_api_key)
    async with client:
        task = await client.get_task(task_id)
        # Membership test: an identity test against the READY_STATES collection
        # can never be true for a single state
        if task.state in READY_STATES:
            await _handle_ready(task, client, already_done=True)
        else:
            # Polling ends by reporting the final state itself, so it must only
            # run for tasks which were not ready yet
            await _handle_alive(task, client, polling_interval_s)
    print(task_id)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@task_app.async_command(help=_RESULT_HELP)
|
|
95
|
+
async def result(
|
|
96
|
+
task_id: Annotated[str, typer.Argument(help=_TASK_ID_HELP)],
|
|
97
|
+
ds_address: Annotated[
|
|
98
|
+
str, typer.Option("--ds-address", "-a", help=_DS_URL_HELP)
|
|
99
|
+
] = DEFAULT_DS_ADDRESS,
|
|
100
|
+
ds_api_key: Annotated[
|
|
101
|
+
Optional[str], typer.Option("--ds-api-key", "-k", help=_DS_API_KEY_HELP)
|
|
102
|
+
] = None,
|
|
103
|
+
) -> Any:
|
|
104
|
+
client = DatashareTaskClient(ds_address, api_key=ds_api_key)
|
|
105
|
+
async with client:
|
|
106
|
+
res = await client.get_task_result(task_id)
|
|
107
|
+
if isinstance(res, (dict, list)):
|
|
108
|
+
res = json.dumps(res, indent=2)
|
|
109
|
+
print(res)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
async def _handle_ready(
    task: Task, client: DatashareTaskClient, already_done: bool = False
) -> None:
    """Dispatch a task to the handler matching its state.

    ``already_done`` selects which of the two "done" handlers is used for tasks
    in the ``DONE`` state.

    Raises:
        ValueError: if the task state is not ``ERROR``, ``CANCELLED`` or ``DONE``.
    """
    state = task.state
    if state == TaskState.ERROR:
        await _handle_error(task, client)
    elif state == TaskState.CANCELLED:
        await _handle_cancelled(task)
    elif state == TaskState.DONE:
        on_done = _handle_already_done if already_done else _handle_done
        await on_done(task)
    else:
        raise ValueError(f"Unexpected task state {task.state}")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
async def _handle_error(task, client: DatashareTaskClient):
    """Print the error of a failed task to stderr and exit with code 1.

    Raises:
        typer.Exit: always, with ``code=1``.
    """
    task_id = task.id
    error = await client.get_task_error(task_id)
    formatted = _format_error(error)
    eprint(f"Task({task_id}) failed with the following error:\n\n{formatted}")
    eprint(f"Task({task_id}) ❌")
    raise typer.Exit(code=1)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
async def _handle_cancelled(task):
    """Report a cancelled task on stderr and exit with code 1.

    Raises:
        typer.Exit: always, with ``code=1``.
    """
    prefix = f"Task({task.id})"
    for suffix in ("was cancelled !", "🛑"):
        eprint(f"{prefix} {suffix}")
    raise typer.Exit(code=1)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
async def _handle_already_done(task):
    """Report on stderr that the task was already completed."""
    message = f"Task({task.id}) ✅ is already completed !"
    eprint(message)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
async def _handle_done(task):
    """Report on stderr that the task has just completed."""
    prefix = f"Task({task.id})"
    for emoji in ("🛬", "✅"):
        eprint(f"{prefix} {emoji}")
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
async def _handle_alive(
    task: Task, client: DatashareTaskClient, polling_interval_s: float
) -> None:
    """Poll a task until it reaches one of the ``READY_STATES``.

    While polling, a manually driven progress bar written to stderr is updated
    with the task progress (``0.0`` when the task has no progress). Once the
    task is ready, its final state is handed over to ``_handle_ready``.
    """
    bar_options = {
        "title": f"Task({task.id}) 🛫",
        "manual": True,
        "stats": "(ETA: {eta})",
        "monitor": "{percent}",
        "file": sys.stderr,
    }
    with alive_bar(**bar_options) as update_bar:
        current = task
        while current.state not in READY_STATES:
            current = await client.get_task(current.id)
            update_bar(current.progress or 0.0)  # pylint: disable=not-callable
            await asyncio.sleep(polling_interval_s)
        # The loop only exits on a ready state
        await _handle_ready(current, client)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _format_error(error: TaskError) -> str:
    """Render a task error as a printable message.

    The message is made of the error name, a stack summary built from the
    entries of ``error.stacktrace`` (their ``name`` and ``lineno`` attributes are
    read) and the error message. When ``error.cause`` is truthy, it is appended
    on a last line.
    """
    # NOTE(review): the entry name is used both as file name and function name
    # and the summary is interpolated as is (list repr) - confirm this is the
    # intended rendering
    stack = StackSummary.from_list(
        [FrameSummary(f.name, f.lineno, f.name) for f in error.stacktrace]
    )
    msg = f"{error.name}:\n{stack}\n{error.message}"
    if error.cause:
        # f-string needed here, otherwise the placeholder is printed literally
        msg += f"\n cause by {error.cause}"
    return msg
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import concurrent.futures
|
|
3
|
+
import sys
|
|
4
|
+
from functools import wraps
|
|
5
|
+
|
|
6
|
+
import typer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AsyncTyper(typer.Typer):
    """``typer.Typer`` with support for registering coroutine functions."""

    def async_command(self, *args, **kwargs):
        """Return a decorator registering a coroutine function as a command.

        ``args`` and ``kwargs`` are forwarded to ``self.command``. The registered
        command is a synchronous wrapper running the coroutine with
        ``asyncio.run``; the decorator hands back the original coroutine function.
        """

        def decorator(async_func):
            @wraps(async_func)
            def run_blocking(*call_args, **call_kwargs):
                coroutine = async_func(*call_args, **call_kwargs)
                return asyncio.run(coroutine)

            register = self.command(*args, **kwargs)
            register(run_blocking)
            return async_func

        return decorator
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def eprint(*args, **kwargs):
    """``print`` to the standard error stream.

    Positional and keyword arguments are forwarded to ``print``; ``file`` is
    already set and must not be passed.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _to_concurrent(
    fut: asyncio.Future, loop: asyncio.AbstractEventLoop
) -> concurrent.futures.Future:
    """Schedule on ``loop`` a coroutine awaiting ``fut``.

    The returned ``concurrent.futures.Future`` is the one produced by
    ``asyncio.run_coroutine_threadsafe``. The awaiting coroutine returns nothing,
    so the value of ``fut`` is not forwarded.
    """

    async def _await_future() -> None:
        await fut

    waiter = _await_future()
    return asyncio.run_coroutine_threadsafe(waiter, loop)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import ClassVar
|
|
2
|
+
|
|
3
|
+
from icij_common.pydantic_utils import ICIJSettings, NoEnumModel
|
|
4
|
+
from icij_worker.utils.logging_ import LogWithWorkerIDMixin
|
|
5
|
+
from pydantic import Field
|
|
6
|
+
|
|
7
|
+
import datashare_python
|
|
8
|
+
|
|
9
|
+
_ALL_LOGGERS = [datashare_python.__name__]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AppConfig(ICIJSettings, LogWithWorkerIDMixin, NoEnumModel):
    """Application settings, also used to build the ES and task clients.

    Settings use the ``DS_DOCKER_ML_`` environment variable prefix.
    """

    class Config:
        env_prefix = "DS_DOCKER_ML_"

    # Names of the loggers of this package
    loggers: ClassVar[list[str]] = Field(_ALL_LOGGERS, const=True)

    log_level: str = Field(default="INFO")

    batch_size: int = 1024
    pipeline_batch_size: int = 1024
    ne_buffer_size: int = 1000

    # DS
    ds_api_key: str | None = None
    ds_url: str = "http://datashare:8080"
    # ES
    es_address: str = "http://localhost:9200"
    es_default_page_size: int = 1000
    es_keep_alive: str = "10m"
    es_max_concurrency: int = 5
    es_max_retries: int = 0
    es_max_retry_wait_s: int | float = 60
    es_timeout_s: int | float = 60 * 5

    def to_es_client(self, address: str | None = None) -> "ESClient":
        """Build an ``ESClient`` from the ``es_*`` settings.

        Args:
            address: ES host to connect to, defaults to ``es_address``.
        """
        # Imported lazily, only when a client is actually requested
        from icij_common.es import ESClient

        if address is None:
            address = self.es_address

        # NOTE(review): the Datashare API key is handed to the ES client as its
        # API key - confirm this is expected
        client = ESClient(
            hosts=[address],
            pagination=self.es_default_page_size,
            max_concurrency=self.es_max_concurrency,
            keep_alive=self.es_keep_alive,
            timeout=self.es_timeout_s,
            max_retries=self.es_max_retries,
            max_retry_wait_s=self.es_max_retry_wait_s,
            api_key=self.ds_api_key,
        )
        # Flags the transport as already verified
        client.transport._verified_elasticsearch = (  # pylint: disable=protected-access
            True
        )
        return client

    def to_task_client(self) -> "DatashareTaskClient":
        """Build a ``DatashareTaskClient`` targeting ``ds_url``.

        The configured ``ds_api_key`` (possibly ``None``) is forwarded so that the
        client authenticates with it when it is set.
        """
        # Imported lazily, only when a client is actually requested
        from datashare_python.task_client import DatashareTaskClient

        return DatashareTaskClient(self.ds_url, api_key=self.ds_api_key)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import Self
|
|
2
|
+
|
|
3
|
+
import pycountry
|
|
4
|
+
from icij_common.es import DOC_CONTENT, DOC_LANGUAGE, DOC_ROOT_ID, ID_, SOURCE
|
|
5
|
+
from icij_common.pydantic_utils import ICIJModel, LowerCamelCaseModel
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Document(LowerCamelCaseModel):
    """Document fields used by the tasks of this package."""

    id: str
    root_document: str
    content: str
    language: str
    tags: list[str] = Field(default_factory=list)
    content_translated: dict[str, str] = Field(
        default_factory=dict, alias="content_translated"
    )

    @classmethod
    def from_es(cls, es_doc: dict) -> Self:
        """Build a document from an ES hit.

        The ID is read from the hit itself, the other fields from its source.
        ``content_translated`` and ``tags`` are optional in the source and default
        to an empty dict and an empty list.
        """
        source = es_doc[SOURCE]
        fields = {
            "id": es_doc[ID_],
            "content": source[DOC_CONTENT],
            "content_translated": source.get("content_translated", {}),
            "language": source[DOC_LANGUAGE],
            "root_document": source[DOC_ROOT_ID],
            "tags": source.get("tags", []),
        }
        return cls(**fields)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Configuration of the classification tasks
class ClassificationConfig(ICIJModel):
    # Constant task name, always "text-classification"
    task: str = Field(const=True, default="text-classification")
    # Name of the model used by default
    model: str = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
    # Number of documents processed together
    batch_size: int = 16
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TranslationConfig(ICIJModel):
    """Configuration of the translation tasks."""

    # Constant task name, always "translation"
    task: str = Field(const=True, default="translation")
    # Prefix of the model name, completed with the language pair
    model: str = "Helsinki-NLP/opus-mt"
    batch_size: int = 16

    def to_pipeline_args(self, source_language: str, *, target_language: str) -> dict:
        """Return the config as a dict specialized for a language pair.

        Both language names are looked up with ``pycountry`` to get their alpha-2
        codes. In the returned dict, ``task`` becomes
        ``translation_<source>_to_<target>`` and ``model`` is suffixed with
        ``-<source>-<target>``.

        Raises:
            ValueError: if a language is unknown or has no alpha-2 code.
        """
        alpha2_codes = []
        for language in (source_language, target_language):
            found = pycountry.languages.get(name=language)
            # Unknown languages and languages without 2 letter code would
            # otherwise fail with an opaque AttributeError
            alpha2 = getattr(found, "alpha_2", None)
            if alpha2 is None:
                raise ValueError(f"can't find an alpha-2 code for language {language!r}")
            alpha2_codes.append(alpha2)
        source_alpha2, target_alpha2 = alpha2_codes
        as_dict = self.dict()
        as_dict["task"] = f"translation_{source_alpha2}_to_{target_alpha2}"
        as_dict["model"] = f"{self.model}-{source_alpha2}-{target_alpha2}"
        return as_dict
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
from icij_common.pydantic_utils import jsonable_encoder
|
|
5
|
+
from icij_worker import Task, TaskError, TaskState
|
|
6
|
+
from icij_worker.exceptions import UnknownTask
|
|
7
|
+
from icij_worker.utils.http import AiohttpClient
|
|
8
|
+
|
|
9
|
+
# TODO: maxRetries is not supported by java, it's automatically set to 3
|
|
10
|
+
_TASK_UNSUPPORTED = {"max_retries"}
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DatashareTaskClient(AiohttpClient):
    """HTTP client for the Datashare task API (``/api/task/...`` routes)."""

    def __init__(self, datashare_url: str, api_key: str | None = None) -> None:
        """Create a client, sending ``api_key`` as a bearer token when provided."""
        headers = None
        if api_key is not None:
            headers = {"Authorization": f"Bearer {api_key}"}
        super().__init__(datashare_url, headers=headers)

    async def __aenter__(self):
        """Enter the client and, without API key, set up a session cookie.

        Returns the client itself so that ``async with client as c`` works.

        Raises:
            ValueError: if exactly one session ID can't be found in the cookie
                returned by ``/settings``.
        """
        await super().__aenter__()
        if "Authorization" not in self._headers:
            async with self._get("/settings") as res:
                # SimpleCookie doesn't seem to parse DS cookie so we perform some dirty
                # hack here
                cookie = res.headers.get("Set-Cookie", "")
                session_id = [item for item in cookie.split("; ") if "session_id" in item]
                if len(session_id) != 1:
                    raise ValueError("Invalid cookie")
                # Split on the first "=" only, the value may contain some
                k, v = session_id[0].split("=", 1)
                self._session.cookie_jar.update_cookies({k: v})
        return self

    async def create_task(
        self,
        name: str,
        args: Dict[str, Any],
        *,
        id_: Optional[str] = None,
        group: Optional[str] = None,
    ) -> str:
        """Create a task and return the ID found in the response.

        Args:
            name: name of the task.
            args: arguments of the task.
            id_: task ID, generated from the task name when omitted.
            group: optional task group, sent as a query parameter.

        Raises:
            TypeError: if ``group`` is provided and is not a string.
        """
        from urllib.parse import quote

        if id_ is None:
            id_ = _generate_task_id(name)
        task = Task.create(task_id=id_, task_name=name, args=args)
        task = jsonable_encoder(task, exclude=_TASK_UNSUPPORTED, exclude_unset=True)
        task.pop("createdAt")
        url = f"/api/task/{id_}"
        if group is not None:
            if not isinstance(group, str):
                raise TypeError(f"expected group to be a string found {group}")
            # Escape the group so that it can't alter the query string
            url += f"?group={quote(group, safe='')}"
        async with self._put(url, json=task) as res:
            task_res = await res.json()
        return task_res["taskId"]

    async def get_task(self, id_: str) -> Task:
        """Fetch a task by ID.

        Raises:
            UnknownTask: if the response body is ``null``.
        """
        url = f"/api/task/{id_}"
        async with self._get(url) as res:
            task = await res.json()
        if task is None:
            raise UnknownTask(id_)
        # TODO: align Java on Python here... it's not a good idea to store results
        # inside tasks since result can be quite large and we may want to get the task
        # metadata without having to deal with the large task results...
        task = _ds_to_icij_worker_task(task)
        task = Task(**task)
        return task

    async def get_tasks(self) -> list[Task]:
        """Fetch all tasks."""
        url = "/api/task/all"
        async with self._get(url) as res:
            tasks = await res.json()
        # TODO: align Java on Python here... it's not a good idea to store results
        # inside tasks since result can be quite large and we may want to get the task
        # metadata without having to deal with the large task results...
        tasks = (_ds_to_icij_worker_task(t) for t in tasks)
        tasks = [Task(**task) for task in tasks]
        return tasks

    async def get_task_state(self, id_: str) -> TaskState:
        """Return the state of a task, fetched with ``get_task``."""
        return (await self.get_task(id_)).state

    async def get_task_result(self, id_: str) -> Any:
        """Return the decoded JSON body of the task results route."""
        url = f"/api/task/{id_}/results"
        async with self._get(url) as res:
            task_res = await res.json()
        return task_res

    async def get_task_error(self, id_: str) -> TaskError:
        """Return the error of a task in the ``ERROR`` state.

        Raises:
            UnknownTask: if the response body is ``null``.
            ValueError: if the task is not in the ``ERROR`` state.
        """
        url = f"/api/task/{id_}"
        async with self._get(url) as res:
            task = await res.json()
        if task is None:
            raise UnknownTask(id_)
        task_state = TaskState[task["state"]]
        if task_state != TaskState.ERROR:
            msg = f"can't find error for task {id_} in state {task_state}"
            raise ValueError(msg)
        error = TaskError(**task["error"])
        return error

    async def delete(self, id_: str):
        """Delete a task by ID."""
        url = f"/api/task/{id_}"
        async with self._delete(url):
            pass

    async def delete_all_tasks(self):
        """Delete, one after the other, all tasks returned by ``get_tasks``."""
        for t in await self.get_tasks():
            await self.delete(t.id)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _generate_task_id(task_name: str) -> str:
    """Return a task ID made of the task name, a dash and a random UUID4."""
    random_suffix = str(uuid.uuid4())
    return "-".join((task_name, random_suffix))
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# Task attributes which are dropped from the tasks received from Datashare
_JAVA_TASK_ATTRIBUTES = ["result", "error"]


def _ds_to_icij_worker_task(task: dict) -> dict:
    """Remove the ``_JAVA_TASK_ATTRIBUTES`` keys from a task dict.

    The dict is modified in place and returned; missing keys are ignored.
    """
    for java_attribute in _JAVA_TASK_ATTRIBUTES:
        if java_attribute in task:
            del task[java_attribute]
    return task
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import AsyncGenerator, Generator, Iterable, Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from elasticsearch._async.helpers import async_bulk
|
|
6
|
+
from icij_common.es import (
|
|
7
|
+
BOOL,
|
|
8
|
+
DOC_CONTENT,
|
|
9
|
+
DOC_CONTENT_TRANSLATED,
|
|
10
|
+
DOC_LANGUAGE,
|
|
11
|
+
DOC_ROOT_ID,
|
|
12
|
+
ESClient,
|
|
13
|
+
HITS,
|
|
14
|
+
ID_,
|
|
15
|
+
MUST_NOT,
|
|
16
|
+
QUERY,
|
|
17
|
+
SHOULD,
|
|
18
|
+
TERM,
|
|
19
|
+
UPDATE,
|
|
20
|
+
and_query,
|
|
21
|
+
bulk_action,
|
|
22
|
+
has_id,
|
|
23
|
+
)
|
|
24
|
+
from icij_worker.ds_task_client import DatashareTaskClient
|
|
25
|
+
from icij_worker.typing_ import PercentProgress
|
|
26
|
+
from icij_worker.utils.progress import to_raw_progress, to_scaled_progress
|
|
27
|
+
from transformers import Pipeline, pipeline
|
|
28
|
+
|
|
29
|
+
from datashare_python.constants import PYTHON_TASK_GROUP
|
|
30
|
+
from datashare_python.objects import ClassificationConfig, Document
|
|
31
|
+
from datashare_python.tasks.dependencies import lifespan_es_client, lifespan_task_client
|
|
32
|
+
from datashare_python.utils import batches
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
async def create_classification_tasks(
    *,
    project: str,
    language: str,
    n_workers: int,
    config: ClassificationConfig | None,
    es_client: ESClient | None = None,
    task_client: DatashareTaskClient | None = None,
    progress: PercentProgress | None = None,
) -> list[str]:
    """Create ``classify_docs`` tasks covering the unclassified documents.

    Unclassified document IDs of the project are fetched and split into batches,
    one ``classify_docs`` task is created per batch in the Python task group.

    Args:
        project: project (index) to search documents in.
        language: language of the documents to classify.
        n_workers: number of workers, used to size the task batches.
        config: classification config, a default one is used when ``None``.
        es_client: ES client, defaults to the lifespan one.
        task_client: task client, defaults to the lifespan one.
        progress: optional progress callback.

    Returns:
        The IDs of the created tasks, empty when there is nothing to classify.

    Raises:
        ValueError: if ``n_workers`` is lower than 1.
    """
    if n_workers < 1:
        raise ValueError("n_workers must be at least 1")
    if es_client is None:
        es_client = lifespan_es_client()
    if task_client is None:
        task_client = lifespan_task_client()
    task_ids = []
    if config is None:
        config = ClassificationConfig()
    # Retrieve unprocessed docs.
    model = config.model
    unclassified = _get_unclassified(
        es_client, project=project, language=language, model=model
    )
    unclassified = [d[ID_] async for d in unclassified]
    n_docs = len(unclassified)
    if not n_docs:
        logger.info("found no unclassified documents !")
        return task_ids
    logger.info("found %s unclassified documents !", n_docs)
    fetch_unclassified_progress = 0.5
    if progress is not None:
        await progress(fetch_unclassified_progress)
    # Roughly split the load between workers:
    # - they should approximately receive the same amount of work
    # - they should receive tasks which are long enough to avoid model loading overhead
    # - task should be short enough to avoid starting all over again from scratch in
    # case of failure
    # NOTE(review): the first term always dominates the second one, the sizing
    # heuristic is kept as is - confirm it is the intended one
    n_tasks = max(n_docs // n_workers, n_docs // (n_workers * 5), 1)
    task_batch_size = n_docs // n_tasks
    # Number of tasks actually created: when n_docs is not a multiple of the batch
    # size it is larger than n_tasks (assumes batches() yields consecutive chunks
    # of at most task_batch_size items - TODO confirm)
    n_batches = -(-n_docs // task_batch_size)
    if progress is not None:
        # We scale the progress to post incremental progress updates from 0 to
        # n_batches
        progress = to_scaled_progress(progress, start=fetch_unclassified_progress)
        progress = to_raw_progress(progress, max_progress=n_batches)
    logger.info("creating %s classification tasks...", n_batches)
    # We create classification tasks which will be picked up by the workers
    shared_args = {"project": project, "config": config.dict(), "language": language}
    for batch in batches(unclassified, task_batch_size):
        # Each task gets its own arguments dict
        args = {**shared_args, "docs": batch}
        task_id = await task_client.create_task(
            "classify_docs", args, group=PYTHON_TASK_GROUP.name
        )
        task_ids.append(task_id)
        if progress is not None:
            await progress(len(task_ids))
    logger.info("created all classification tasks !")
    return task_ids
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
_CLASSIF_DOC_SOURCES = [DOC_CONTENT, DOC_ROOT_ID, DOC_CONTENT_TRANSLATED, DOC_LANGUAGE]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
async def classify_docs(
    docs: list[str],
    *,
    language: str,
    project: str,
    config: ClassificationConfig | None = None,
    progress: PercentProgress | None = None,
    es_client: ESClient | None = None,
) -> int:
    """Classify documents and tag them with the predicted label.

    Documents are processed by batches of ``config.batch_size``: they are fetched
    from the project index, the content available in ``language`` (original or
    translated) is classified and a ``classified:<model>:<label>`` tag is added
    to each document. Documents without content in ``language`` are skipped.

    Args:
        docs: IDs of the documents to classify.
        language: language of the content to classify.
        project: project (index) holding the documents.
        config: classification config, a default one is used when ``None``.
        progress: optional progress callback.
        es_client: ES client, defaults to the lifespan one.

    Returns:
        The number of document IDs received.
    """
    if config is None:
        # Built per call rather than shared through a default argument value
        config = ClassificationConfig()
    n_docs = len(docs)
    if not n_docs:
        # Nothing to do, don't load the model for nothing
        return n_docs
    if es_client is None:
        es_client = lifespan_es_client()
    model = config.model
    # Torch/macOS silicon stuff
    device = None
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    # Load the classification pipeline
    pipe = pipeline(config.task, model=model, device=device)
    model = pipe.model.name_or_path
    if progress is not None:
        # Convert the progress to a "raw" progress to update the progress
        # incrementally from 0 to n_docs (rather than 0.0 to 1.0)
        progress = to_raw_progress(progress, max_progress=n_docs)
    seen = 0
    # We batch the data ourselves, ideally, we should use an async version of:
    # https://huggingface.co/docs/datasets/v3.1.0/en/package_reference/main_classes#datasets.Dataset.from_generator
    for batch in batches(docs, batch_size=config.batch_size):
        batch_length = len(batch)
        batch_docs = []
        # The search is restricted to the project index, which is also the index
        # updated when tagging
        async for page in es_client.poll_search_pages(
            index=project,
            body={QUERY: has_id(batch)},
            _source_includes=_CLASSIF_DOC_SOURCES,
        ):
            batch_docs.extend(Document.from_es(doc) for doc in page[HITS][HITS])
        classifiable = [
            (doc, content)
            for doc in batch_docs
            if (content := _get_language_content(doc, language)) is not None
        ]
        # A batch may hold no classifiable doc, unpacking an empty zip would fail
        if classifiable:
            batch_docs, contents = zip(*classifiable)
            labels = _classify(pipe, list(contents))
            # We add the classification results by updating the documents with
            # new tags, this could also be done using:
            # https://github.com/ICIJ/datashare-tarentula
            await _add_classification_tags(
                es_client, zip(batch_docs, labels), project, model=model
            )
        seen += batch_length
        if progress is not None:
            await progress(seen)
    # Return the number of classified documents
    return n_docs
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _classify(pipe: Pipeline, texts: list[str]) -> Generator[str, None, None]:
    """Lazily yield the ``"label"`` of each result returned by the pipeline.

    The pipeline is called once with all the texts, with padding and truncation
    enabled.
    """
    # In practice, we should chunk the text
    predictions = pipe(texts, padding=True, truncation=True)
    yield from (prediction["label"] for prediction in predictions)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _get_language_content(doc: Document, language: str) -> Optional[str]:
    """Return the document content in ``language``.

    This is the original content when the document is in ``language``, otherwise
    its translation in ``language``, or ``None`` when there is no such
    translation.
    """
    if doc.language != language:
        return doc.content_translated.get(language)
    return doc.content
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
_SCRIPT_SOURCES = """
|
|
164
|
+
if( !ctx._source.containsKey("tags") ) {
|
|
165
|
+
ctx._source.tags = [];
|
|
166
|
+
}
|
|
167
|
+
if( !ctx._source.tags.contains(params.tag) ) {
|
|
168
|
+
ctx._source.tags.add(params.tag);
|
|
169
|
+
}
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
async def _add_classification_tags(
    es_client: ESClient,
    tags: Iterable[tuple[Document, str]],
    project: str,
    *,
    model: str,
):
    """Bulk update documents to add them a ``classified:<model>:<label>`` tag.

    One scripted update is issued per ``(document, label)`` pair, on the project
    index and routed by the root document. Bulk errors are raised.
    """

    def _iter_actions():
        for doc, label in tags:
            script = {
                "source": _SCRIPT_SOURCES,
                "lang": "painless",
                "params": {"tag": f"classified:{model}:{label}"},
            }
            yield bulk_action(
                op_type=UPDATE,
                index=project,
                id_=doc.id,
                routing=doc.root_document,
                script=script,
            )

    await async_bulk(
        es_client, _iter_actions(), raise_on_error=True, refresh="wait_for"
    )
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _unclassified_query(model: str, language: str):
    """Build the query matching the documents left to classify.

    Both conditions are combined with ``and_query``: the document has no tag
    prefixed by ``classified:<model>:``, and it is either in ``language`` or
    translated in ``language``.
    """
    # Get documents which aren't tagged yet
    tag_prefix = f"classified:{model}:"
    not_tagged = {BOOL: {MUST_NOT: {"prefix": {"tags": {"value": tag_prefix}}}}}
    # And which are either in the model language or are translated in the model
    # language
    is_translated = {"exists": {"field": f"{DOC_CONTENT_TRANSLATED}.{language}"}}
    in_language = {TERM: {DOC_LANGUAGE: language}}
    has_language_content = {BOOL: {SHOULD: [is_translated, in_language]}}
    return and_query(not_tagged, has_language_content)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
async def _get_unclassified(
    es_client: ESClient, project: str, *, language: str, model: str, **kwargs
) -> AsyncGenerator[dict, None]:
    """Yield, page after page, the hits of the unclassified documents search.

    The search runs on the project index, sorted by ``_doc`` and without
    document sources. Extra ``kwargs`` are forwarded to the paginated search.
    """
    pages = es_client.poll_search_pages(
        index=project,
        body=_unclassified_query(model, language=language),
        sort="_doc:asc",
        _source=False,
        **kwargs,
    )
    async for page in pages:
        for hit in page[HITS][HITS]:
            yield hit
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from icij_common.es import ESClient
|
|
4
|
+
from icij_worker import WorkerConfig
|
|
5
|
+
from icij_worker.ds_task_client import DatashareTaskClient
|
|
6
|
+
from icij_worker.utils.dependencies import DependencyInjectionError
|
|
7
|
+
|
|
8
|
+
from datashare_python.config import AppConfig
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
# Lifespan dependencies consist of global variables which can be loaded in functions by
|
|
13
|
+
# calling lifespan_<dep_name>(), which returns the global variable.
|
|
14
|
+
# The variable itself is created and setup in <>_setup function and if needed
|
|
15
|
+
# torn down in the <>_teardown function.
|
|
16
|
+
# The setup and tear down functions are registered in the APP_LIFESPAN_DEPS list which
|
|
17
|
+
# is then passed to the AsyncApp when creating it. The app will take care of setting up
|
|
18
|
+
# and tearing down all dependencies in the list. Since a dep might depend on another
|
|
19
|
+
# one, the order in which they are registered is important.
|
|
20
|
+
# We hence start by registering the configuration, other deps are created from it.
|
|
21
|
+
|
|
22
|
+
_ASYNC_APP_CONFIG: AppConfig | None = None
|
|
23
|
+
_ES_CLIENT: ESClient | None = None
|
|
24
|
+
_TASK_CLIENT: DatashareTaskClient | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# App loading setup
|
|
28
|
+
def load_app_config(worker_config: WorkerConfig, **_):
    """Populate the global app config.

    The config is parsed from ``worker_config.app_bootstrap_config_path`` when
    that path is set, otherwise a default ``AppConfig()`` is created.
    """
    global _ASYNC_APP_CONFIG
    bootstrap_path = worker_config.app_bootstrap_config_path
    if bootstrap_path is None:
        _ASYNC_APP_CONFIG = AppConfig()
        return
    _ASYNC_APP_CONFIG = AppConfig.parse_file(bootstrap_path)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Returns the globally injected config
|
|
39
|
+
def lifespan_config() -> AppConfig:
    """Return the global app config.

    Raises:
        DependencyInjectionError: if the config has not been loaded.
    """
    loaded_config = _ASYNC_APP_CONFIG
    if loaded_config is not None:
        return loaded_config
    raise DependencyInjectionError("config")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Loggers setup
|
|
46
|
+
def setup_loggers(worker_id: str, **_):
    """Set up the loggers from the app config, then log that config."""
    app_config = lifespan_config()
    app_config.setup_loggers(worker_id=worker_id)
    logger.info("worker loggers ready to log 💬")
    serialized_config = app_config.json(indent=2)
    logger.info("app config: %s", serialized_config)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Elasticsearch client setup
|
|
54
|
+
async def es_client_setup(**_):
    """Create the global ES client from the app config and enter its context."""
    # pylint: disable=unnecessary-dunder-call
    global _ES_CLIENT
    client = lifespan_config().to_es_client()
    # Register the client before entering it, its exit is handled by the teardown
    _ES_CLIENT = client
    await client.__aenter__()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Elasticsearch client teardown
|
|
63
|
+
async def es_client_teardown(exc_type, exc_val, exc_tb):
    """Exit the global ES client context and unregister the client.

    The arguments are forwarded unchanged to the client's ``__aexit__``.

    Raises:
        DependencyInjectionError: if no ES client is currently registered.
    """
    # pylint: disable=unnecessary-dunder-call
    global _ES_CLIENT
    try:
        await lifespan_es_client().__aexit__(exc_type, exc_val, exc_tb)
    finally:
        # Always drop the reference, even when closing failed, so that a stale
        # client can't be handed out by lifespan_es_client() afterwards
        _ES_CLIENT = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Returns the globally injected ES client
|
|
71
|
+
def lifespan_es_client() -> ESClient:
    """Return the global ES client.

    Raises:
        DependencyInjectionError: if the client has not been set up.
    """
    injected_client = _ES_CLIENT
    if injected_client is not None:
        return injected_client
    raise DependencyInjectionError("es client")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Task client setup
|
|
79
|
+
async def task_client_setup(**_):
    """Create the global task client from the app config and enter its context."""
    # pylint: disable=unnecessary-dunder-call
    global _TASK_CLIENT
    client = lifespan_config().to_task_client()
    # Register the client before entering it, its exit is handled by the teardown
    _TASK_CLIENT = client
    await client.__aenter__()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Task client teardown
|
|
88
|
+
async def task_client_teardown(exc_type, exc_val, exc_tb):
    """Exit the global task client context and unregister the client.

    The arguments are forwarded unchanged to the client's ``__aexit__``.

    Raises:
        DependencyInjectionError: if no task client is currently registered.
    """
    # pylint: disable=unnecessary-dunder-call
    global _TASK_CLIENT
    try:
        await lifespan_task_client().__aexit__(exc_type, exc_val, exc_tb)
    finally:
        # Always drop the reference, even when closing failed, so that a stale
        # client can't be handed out by lifespan_task_client() afterwards
        _TASK_CLIENT = None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Returns the globally injected task client
|
|
96
|
+
def lifespan_task_client() -> DatashareTaskClient:
    """Return the global task client.

    Raises:
        DependencyInjectionError: if the client has not been set up.
    """
    injected_client = _TASK_CLIENT
    if injected_client is not None:
        return injected_client
    raise DependencyInjectionError("task client")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Register all dependencies in the format of:
|
|
104
|
+
# (<logging helper>, <dep setup>, <dep teardown>)
|
|
105
|
+
APP_LIFESPAN_DEPS = [
|
|
106
|
+
("loading async app configuration", load_app_config, None),
|
|
107
|
+
("loggers", setup_loggers, None),
|
|
108
|
+
("elasticsearch client", es_client_setup, es_client_teardown),
|
|
109
|
+
("task client", task_client_setup, task_client_teardown),
|
|
110
|
+
]
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import AsyncGenerator, Generator, Iterable
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from aiostream.stream import chain
|
|
7
|
+
from elasticsearch._async.helpers import async_bulk
|
|
8
|
+
from icij_common.es import (
|
|
9
|
+
BOOL,
|
|
10
|
+
COUNT,
|
|
11
|
+
DOC_CONTENT,
|
|
12
|
+
DOC_LANGUAGE,
|
|
13
|
+
DOC_ROOT_ID,
|
|
14
|
+
ESClient,
|
|
15
|
+
HITS,
|
|
16
|
+
ID_,
|
|
17
|
+
QUERY,
|
|
18
|
+
SOURCE,
|
|
19
|
+
TERM,
|
|
20
|
+
has_id,
|
|
21
|
+
must_not,
|
|
22
|
+
)
|
|
23
|
+
from icij_worker.ds_task_client import DatashareTaskClient
|
|
24
|
+
from icij_worker.typing_ import PercentProgress
|
|
25
|
+
from icij_worker.utils.progress import to_raw_progress
|
|
26
|
+
from transformers import Pipeline, pipeline
|
|
27
|
+
|
|
28
|
+
from datashare_python.constants import PYTHON_TASK_GROUP
|
|
29
|
+
from datashare_python.objects import Document, TranslationConfig
|
|
30
|
+
from datashare_python.tasks.dependencies import lifespan_es_client, lifespan_task_client
|
|
31
|
+
from datashare_python.utils import async_batches, batches, before_and_after, once
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
async def create_translation_tasks(
    *,
    project: str,
    target_language: str,
    config: TranslationConfig | None = None,
    es_client: ESClient | None = None,
    task_client: DatashareTaskClient | None = None,
) -> list[str]:
    """Create ``translate_docs`` tasks for the untranslated docs of ``project``.

    Untranslated documents are grouped by source language and split in batches,
    one ``translate_docs`` task is created per batch.

    Args:
        project: index holding the documents.
        target_language: language the documents must be translated into.
        config: translation config, defaults to ``TranslationConfig()``.
        es_client: ES client, defaults to the lifespan ES client.
        task_client: task client, defaults to the lifespan task client.

    Returns:
        The IDs of the created tasks.
    """
    if es_client is None:
        es_client = lifespan_es_client()
    if task_client is None:
        task_client = lifespan_task_client()
    task_ids = []
    if config is None:
        config = TranslationConfig()
    # Retrieve unprocessed docs.
    docs_by_language = _untranslated_by_language(
        es_client, project, target_language=target_language
    )
    base_args = {
        "project": project,
        "config": config.dict(),
        "target_language": target_language,
    }
    # We could set this to a smarter value
    task_batch_size = config.batch_size * 4
    current_language = None
    async for language_docs in docs_by_language:
        async for batch in async_batches(language_docs, task_batch_size):
            language = batch[0][SOURCE][DOC_LANGUAGE]
            if language != current_language:
                logger.info("creating translation task for docs in %s", language)
                # Remember the language so that it's logged once, not per batch
                current_language = language
            # Each task gets its own args dict rather than a shared mutated one
            args = {**base_args, "docs": [doc[ID_] for doc in batch]}
            task_id = await task_client.create_task(
                "translate_docs", args, group=PYTHON_TASK_GROUP.name
            )
            task_ids.append(task_id)
    logger.info("done creating %s translation tasks", len(task_ids))
    return task_ids
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
_TRANSLATION_DOC_SOURCES = [DOC_CONTENT, DOC_ROOT_ID, DOC_LANGUAGE]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
async def translate_docs(
    docs: list[str],
    target_language: str,
    *,
    project: str,
    es_client: ESClient | None = None,
    progress: PercentProgress | None = None,
    config: TranslationConfig = TranslationConfig(),
) -> int:
    """Translate documents into ``target_language`` and store the translations.

    Args:
        docs: IDs of the documents to translate. The translation pipeline is
            built once from the language of the first retrieved document.
        target_language: language to translate into.
        project: index holding the documents.
        es_client: ES client, defaults to the lifespan ES client.
        progress: optional progress callback.
        config: translation config (batch size, pipeline arguments).

    Returns:
        The number of document IDs received, 0 when ``docs`` is empty.
    """
    if es_client is None:
        es_client = lifespan_es_client()
    n_docs = len(docs)
    if not n_docs:
        return 0
    # Torch/macOS silicon stuff
    device = None
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    seen = 0
    # Convert the progress to a "raw" progress to update the progress incrementally
    # rather than setting the progress rate. Only wrap an actual callback so that
    # the `progress is not None` check below stays meaningful
    if progress is not None:
        progress = to_raw_progress(progress, max_progress=n_docs)
    pipe = None
    # We batch the data ourselves, ideally, we should use an async version of:
    # https://huggingface.co/docs/datasets/v3.1.0/en/package_reference/main_classes#datasets.Dataset.from_generator
    for batch in batches(docs, batch_size=config.batch_size):
        batch_docs = []
        # Restrict the search to the project index, which is also the index the
        # translations are written to
        async for page in es_client.poll_search_pages(
            index=project,
            body={QUERY: has_id(batch)},
            _source_includes=_TRANSLATION_DOC_SOURCES,
        ):
            batch_docs.extend(Document.from_es(doc) for doc in page[HITS][HITS])
        # None of the batch IDs may be found, there's nothing to translate then
        if batch_docs:
            if pipe is None:
                # Load the translation pipeline
                source_language = batch_docs[0].language
                kwargs = config.to_pipeline_args(
                    source_language, target_language=target_language
                )
                pipe = pipeline(device=device, **kwargs)
            contents = [d.content for d in batch_docs]
            translations = _translate(pipe, contents)
            await _add_translation(
                es_client,
                zip(batch_docs, translations),
                project,
                target_language=target_language,
            )
        seen += len(batch)
        if progress is not None:
            await progress(seen)
    # Return the number of documents to translate
    return n_docs
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _translate(pipe: Pipeline, texts: list[str]) -> Generator[str, None, None]:
|
|
136
|
+
for res in pipe(texts):
|
|
137
|
+
yield res["translation_text"]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _has_language(doc: dict, language: str) -> bool:
    """Tell whether the language stored in the source of the ES hit is ``language``."""
    doc_language = doc[SOURCE][DOC_LANGUAGE]
    return doc_language == language
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
async def _untranslated_by_language(
    es_client: ESClient, project: str, target_language: str
) -> AsyncGenerator[AsyncGenerator[dict, None], None]:
    """Group the untranslated documents of ``project`` by language.

    Hits are read from ``_get_untranslated`` and each yielded group is an async
    iterable over a run of consecutive hits sharing the same language.

    A group must be consumed entirely before requesting the next one: the
    remaining hits only become available once the group iterator has reached
    the first hit in another language or the end of the hits.
    """
    docs = _get_untranslated(es_client, project, target_language=target_language)
    while True:
        try:
            next_doc = await anext(aiter(docs))
        except StopAsyncIteration:
            return
        current_language = next_doc[SOURCE][DOC_LANGUAGE]
        # Split the remaining hits into the ones in the current language and the rest
        language_docs, docs = before_and_after(
            docs, partial(_has_language, language=current_language)
        )
        # Put back the hit consumed above to detect the group language
        yield chain(once(next_doc), language_docs)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
_SCRIPT_SOURCES = """
|
|
161
|
+
if( !ctx._source.containsKey("content_translated") ) {
|
|
162
|
+
ctx._source.content_translated = new HashMap();
|
|
163
|
+
}
|
|
164
|
+
ctx._source.content_translated[params.language] = params.translation;
|
|
165
|
+
"""
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
async def _add_translation(
    es_client: ESClient,
    translations: Iterable[tuple[Document, str]],
    project: str,
    *,
    target_language: str,
):
    """Store each translation on its document through one bulk update.

    Every ``(document, translation)`` pair becomes a scripted update of the
    document in the ``project`` index, run with ``target_language`` and the
    translation as script parameters.
    """

    def _update_actions():
        # Built lazily so that the actions are streamed to the bulk helper
        for document, translated_text in translations:
            script_params = {
                "language": target_language,
                "translation": translated_text,
            }
            yield {
                "_op_type": "update",
                "_index": project,
                "_routing": document.root_document,
                ID_: document.id,
                "script": {
                    "source": _SCRIPT_SOURCES,
                    "lang": "painless",
                    "params": script_params,
                },
            }

    await async_bulk(
        es_client, _update_actions(), raise_on_error=True, refresh="wait_for"
    )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _untranslated_query(target_language: str):
    """Build the search body selecting the docs left to translate.

    Docs which already have a translation in ``target_language`` or which are
    written in ``target_language`` are excluded.
    """
    already_translated = {
        "exists": {"field": f"content_translated.{target_language}"}
    }
    written_in_target = {TERM: {DOC_LANGUAGE: target_language}}
    return {"query": {BOOL: must_not(already_translated, written_in_target)}}
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
async def _get_untranslated(
    es_client: ESClient, project: str, *, target_language: str
) -> AsyncGenerator[dict, None]:
    """Yield, one hit at a time, the documents of ``project`` left to translate.

    Only the language field of each document source is requested and hits are
    sorted by language first, then by ``_doc``.
    """
    search_body = _untranslated_query(target_language)
    pages = es_client.poll_search_pages(
        index=project,
        body=search_body,
        _source_includes=[DOC_LANGUAGE],
        sort=[f"{DOC_LANGUAGE}:asc", "_doc:asc"],
    )
    async for page in pages:
        for es_hit in page[HITS][HITS]:
            yield es_hit
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
async def _count_untranslated(
    es_client: ESClient, project: str, *, target_language: str
) -> int:
    """Count the documents of ``project`` matching the untranslated query."""
    count_body = _untranslated_query(target_language)
    response = await es_client.count(index=project, body=count_body)
    return response[COUNT]
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
from itertools import islice
|
|
4
|
+
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, TypeVar
|
|
5
|
+
|
|
6
|
+
T = TypeVar("T")
|
|
7
|
+
|
|
8
|
+
Predicate = Callable[[T], bool] | Callable[[T], Awaitable[bool]]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
async def async_batches(
|
|
12
|
+
iterable: AsyncIterable[T], batch_size: int
|
|
13
|
+
) -> AsyncIterator[tuple[T]]:
|
|
14
|
+
it = aiter(iterable)
|
|
15
|
+
if batch_size < 1:
|
|
16
|
+
raise ValueError("n must be at least one")
|
|
17
|
+
while True:
|
|
18
|
+
batch = []
|
|
19
|
+
while len(batch) < batch_size:
|
|
20
|
+
try:
|
|
21
|
+
batch.append(await anext(it))
|
|
22
|
+
except StopAsyncIteration:
|
|
23
|
+
if batch:
|
|
24
|
+
yield tuple(batch)
|
|
25
|
+
return
|
|
26
|
+
yield tuple(batch)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def batches(iterable: Iterable[T], batch_size: int):
    """Yield the items of ``iterable`` as tuples of at most ``batch_size`` items.

    Only the last tuple can be shorter than ``batch_size``; nothing is yielded
    for an empty iterable.

    Raises:
        ValueError: if ``batch_size`` is lower than 1.
    """
    if batch_size < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while True:
        batch = tuple(islice(it, batch_size))
        if not batch:
            return
        yield batch
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
async def maybe_await(maybe_awaitable: Awaitable[T] | T) -> T:
    """Return ``maybe_awaitable`` itself, or its result when it is awaitable."""
    if not inspect.isawaitable(maybe_awaitable):
        return maybe_awaitable
    awaited = await maybe_awaitable
    return awaited
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
async def once(item: T) -> AsyncIterator[T]:
    """Wrap ``item`` into an async iterator yielding it exactly once."""
    for single in (item,):
        yield single
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def before_and_after(
    iterable: AsyncIterable[T], predicate: Predicate[T]
) -> tuple[AsyncIterable[T], AsyncIterable[T]]:
    """Split ``iterable`` at the first item which doesn't match ``predicate``.

    Async counterpart of the ``before_and_after`` itertools recipe.

    Args:
        iterable: async iterable to split.
        predicate: sync or async function called on each leading item.

    Returns:
        Two async iterators: the first one yields the leading items for which
        ``predicate`` is true, the second one yields the first non-matching
        item followed by all remaining items. The second iterator waits for
        the first one to reach the transition (or the end of the input), so
        the first one has to be consumed entirely before the second.
    """
    # Share a single iterator between both generators, otherwise a re-iterable
    # input would be restarted from scratch when iterating over the remainder
    it = aiter(iterable)
    transition = asyncio.get_event_loop().create_future()

    async def true_iterator():
        async for elem in it:
            if await maybe_await(predicate(elem)):
                yield elem
            else:
                # Hand over the first non-matching item to the remainder
                transition.set_result(elem)
                return
        # Input exhausted without transition: the remainder is empty
        transition.set_exception(StopAsyncIteration)

    async def remainder_iterator():
        try:
            first = await transition
        except StopAsyncIteration:
            return
        yield first
        async for elm in it:
            yield elm

    return true_iterator(), remainder_iterator()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: datashare-python
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Implement Datashare task in Python
|
|
5
|
+
Author-Email: =?utf-8?q?Cl=C3=A9ment_Doumouro?= <cdoumouro@icij.org>, =?utf-8?q?Cl=C3=A9ment_Doumouro?= <clement.doumouro@gmail.com>
|
|
6
|
+
Requires-Python: ~=3.11
|
|
7
|
+
Requires-Dist: aiostream~=0.6.4
|
|
8
|
+
Requires-Dist: aiohttp~=3.11.9
|
|
9
|
+
Requires-Dist: icij-common[elasticsearch]~=0.5.5
|
|
10
|
+
Requires-Dist: icij-worker[amqp]~=0.12
|
|
11
|
+
Requires-Dist: torch==2.6.0.dev20241101; sys_platform != "darwin"
|
|
12
|
+
Requires-Dist: torch!=2.6.0.dev20241101+cpu,<=2.6.0.dev20241101; sys_platform == "darwin"
|
|
13
|
+
Requires-Dist: transformers~=4.46.3
|
|
14
|
+
Requires-Dist: pycountry>=24.6.1
|
|
15
|
+
Requires-Dist: sentencepiece>=0.2.0
|
|
16
|
+
Requires-Dist: typer>=0.13.1
|
|
17
|
+
Requires-Dist: alive-progress>=3.2.0
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
|
|
20
|
+
<div style="background-image: linear-gradient(45deg, #193d87, #fa4070);">
|
|
21
|
+
<br/>
|
|
22
|
+
<p align="center">
|
|
23
|
+
<a href="https://datashare.icij.org/">
|
|
24
|
+
<img align="center" src="docs/assets/datashare-logo.svg" alt="Datashare" style="max-width: 60%">
|
|
25
|
+
</a>
|
|
26
|
+
</p>
|
|
27
|
+
<p align="center">
|
|
28
|
+
<em>Better analyze information, in all its forms</em>
|
|
29
|
+
</p>
|
|
30
|
+
<br/>
|
|
31
|
+
</div>
|
|
32
|
+
<br/>
|
|
33
|
+
|
|
34
|
+
---
|
|
35
|
+
|
|
36
|
+
**Documentation**: <a href="https://icij.github.io/datashare-python" target="_blank">https://icij.github.io/datashare-python</a>
|
|
37
|
+
|
|
38
|
+
---
|
|
39
|
+
|
|
40
|
+
# Implement **your own Datashare tasks**, written in Python
|
|
41
|
+
|
|
42
|
+
Most AI, Machine Learning, Data Engineering happens in Python.
|
|
43
|
+
[Datashare](https://icij.gitbook.io/datashare) now lets you extend its backend with your own tasks implemented in Python.
|
|
44
|
+
|
|
45
|
+
Turning your own ML pipelines into Datashare tasks is **very simple**; learn how in the [documentation](https://icij.github.io/datashare-python).
|
|
46
|
+
|
|
47
|
+
Turning your own ML pipelines into Datashare tasks is **very simple**.
|
|
48
|
+
|
|
49
|
+
Actually, it's *almost* as simple as cloning our [template repo](https://github.com/ICIJ/datashare-python):
|
|
50
|
+
|
|
51
|
+
```
|
|
52
|
+
$ git clone git@github.com:ICIJ/datashare-python.git
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
replacing existing [app](https://github.com/ICIJ/datashare-python/blob/main/datashare_python/app.py) tasks with your own:
|
|
56
|
+
```python
|
|
57
|
+
from icij_worker import AsyncApp
|
|
58
|
+
|
|
59
|
+
app = AsyncApp("app")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@app.task
|
|
63
|
+
def hello_world() -> str:
|
|
64
|
+
return "Hello world"
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
installing [`uv`](https://docs.astral.sh/uv/) to set up dependencies and running your async Datashare worker:
|
|
68
|
+
```console
|
|
69
|
+
$ cd datashare-python
|
|
70
|
+
$ curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
71
|
+
$ uv run ./scripts/worker_entrypoint.sh
|
|
72
|
+
[INFO][icij_worker.backend.backend]: Loading worker configuration from env...
|
|
73
|
+
...
|
|
74
|
+
}
|
|
75
|
+
[INFO][icij_worker.backend.mp]: starting 1 worker for app datashare_python.app.app
|
|
76
|
+
...
|
|
77
|
+
```
|
|
78
|
+
You'll then be able to execute tasks using our [HTTP client]() (and soon using Datashare's UI).
|
|
79
|
+
|
|
80
|
+
## Learn more by reading our [documentation](https://icij.github.io/datashare-python)!
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
datashare_python-0.1.0.dist-info/METADATA,sha256=xt9hC3iIosOo3sEeePvELBAIYx-yg7b80wB9AewYlzs,2796
|
|
2
|
+
datashare_python-0.1.0.dist-info/WHEEL,sha256=thaaA2w1JzcGC48WYufAs8nrYZjJm8LqNfnXFOFyCC4,90
|
|
3
|
+
datashare_python-0.1.0.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
|
|
4
|
+
datashare_python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
+
datashare_python/__main__.py,sha256=agVB_DGJFNXl4AVVwb14i6fF8AXzyAXp8ykOGRPtB3A,83
|
|
6
|
+
datashare_python/app.py,sha256=QI8fVR6s6YU7S45s3hTGM7DqdISA4T7XnXU4UYW9-0E,2582
|
|
7
|
+
datashare_python/cli/__init__.py,sha256=wcNr-Tp5YmJCYAx4rBTq5EMaoNcPB-9On8LDcopg5v4,818
|
|
8
|
+
datashare_python/cli/tasks.py,sha256=rxYlDs0WWe1-bBtGshw4TxZ14u-uaeOTsx4DhZX0Ayo,5870
|
|
9
|
+
datashare_python/cli/utils.py,sha256=vykjfBBYW5hZVWn7eR3YyBb5NcV35mO7k_wIB5UNJiA,772
|
|
10
|
+
datashare_python/config.py,sha256=vsNExCz8J8mBUm7y-6FxGKzRVh5ULnlu3k9015lBXOI,1813
|
|
11
|
+
datashare_python/constants.py,sha256=naesOTqywwIN5IIZH3GLBcst7qLOvcQZLwxeZZ_IBI0,161
|
|
12
|
+
datashare_python/objects.py,sha256=ThFhy0XuVUYx_A47wJmFy3eID1kbSDcoxTiRbK3RrEo,1708
|
|
13
|
+
datashare_python/task_client.py,sha256=sVKFg7k3q7JjMEOObcdo4f7vCF72kbhPq5A4lGbu2qA,4528
|
|
14
|
+
datashare_python/tasks/__init__.py,sha256=xhlmAx5SMJ14kJQEZj8jpQU7_1lnXYhFhr7xpChtclQ,139
|
|
15
|
+
datashare_python/tasks/classify_docs.py,sha256=I1y6SyrRb_zcppUu_RmhepXjg00OlLbJ75Exw2BYNls,7764
|
|
16
|
+
datashare_python/tasks/dependencies.py,sha256=Niua0YrqIVgXTIfEghwaJmRia1XDGkiBVl__Hge8fFk,3661
|
|
17
|
+
datashare_python/tasks/translate_docs.py,sha256=Qu-b1qZUj18PTI-rsr3oSCGTHLYhNzf0ip9UC7G-CJY,7142
|
|
18
|
+
datashare_python/utils.py,sha256=y-ZoHt-1v-18o7EoCXD2fMYUxwrz7mZLgbzbggkm_Vc,1901
|
|
19
|
+
datashare_python-0.1.0.dist-info/RECORD,,
|