datashare-python 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. datashare_python-0.1.0/PKG-INFO +80 -0
  2. datashare_python-0.1.0/README.md +61 -0
  3. datashare_python-0.1.0/datashare_python/__init__.py +0 -0
  4. datashare_python-0.1.0/datashare_python/__main__.py +4 -0
  5. datashare_python-0.1.0/datashare_python/app.py +85 -0
  6. datashare_python-0.1.0/datashare_python/cli/__init__.py +30 -0
  7. datashare_python-0.1.0/datashare_python/cli/tasks.py +182 -0
  8. datashare_python-0.1.0/datashare_python/cli/utils.py +33 -0
  9. datashare_python-0.1.0/datashare_python/config.py +60 -0
  10. datashare_python-0.1.0/datashare_python/constants.py +6 -0
  11. datashare_python-0.1.0/datashare_python/objects.py +49 -0
  12. datashare_python-0.1.0/datashare_python/task_client.py +124 -0
  13. datashare_python-0.1.0/datashare_python/tasks/__init__.py +2 -0
  14. datashare_python-0.1.0/datashare_python/tasks/classify_docs.py +227 -0
  15. datashare_python-0.1.0/datashare_python/tasks/dependencies.py +110 -0
  16. datashare_python-0.1.0/datashare_python/tasks/translate_docs.py +223 -0
  17. datashare_python-0.1.0/datashare_python/utils.py +69 -0
  18. datashare_python-0.1.0/pyproject.toml +72 -0
  19. datashare_python-0.1.0/tests/__init__.py +0 -0
  20. datashare_python-0.1.0/tests/cli/__init__.py +0 -0
  21. datashare_python-0.1.0/tests/cli/test_tasks.py +193 -0
  22. datashare_python-0.1.0/tests/conftest.py +281 -0
  23. datashare_python-0.1.0/tests/tasks/test_translate_docs.py +37 -0
  24. datashare_python-0.1.0/tests/test_task_client.py +196 -0
  25. datashare_python-0.1.0/tests/test_tasks.py +181 -0
  26. datashare_python-0.1.0/tests/test_utils.py +31 -0
@@ -0,0 +1,80 @@
1
+ Metadata-Version: 2.1
2
+ Name: datashare-python
3
+ Version: 0.1.0
4
+ Summary: Implement Datashare task in Python
5
+ Author-Email: =?utf-8?q?Cl=C3=A9ment_Doumouro?= <cdoumouro@icij.org>, =?utf-8?q?Cl=C3=A9ment_Doumouro?= <clement.doumouro@gmail.com>
6
+ Requires-Python: ~=3.11
7
+ Requires-Dist: aiostream~=0.6.4
8
+ Requires-Dist: aiohttp~=3.11.9
9
+ Requires-Dist: icij-common[elasticsearch]~=0.5.5
10
+ Requires-Dist: icij-worker[amqp]~=0.12
11
+ Requires-Dist: torch==2.6.0.dev20241101; sys_platform != "darwin"
12
+ Requires-Dist: torch!=2.6.0.dev20241101+cpu,<=2.6.0.dev20241101; sys_platform == "darwin"
13
+ Requires-Dist: transformers~=4.46.3
14
+ Requires-Dist: pycountry>=24.6.1
15
+ Requires-Dist: sentencepiece>=0.2.0
16
+ Requires-Dist: typer>=0.13.1
17
+ Requires-Dist: alive-progress>=3.2.0
18
+ Description-Content-Type: text/markdown
19
+
20
+ <div style="background-image: linear-gradient(45deg, #193d87, #fa4070);">
21
+ <br/>
22
+ <p align="center">
23
+ <a href="https://datashare.icij.org/">
24
+ <img align="center" src="docs/assets/datashare-logo.svg" alt="Datashare" style="max-width: 60%">
25
+ </a>
26
+ </p>
27
+ <p align="center">
28
+ <em>Better analyze information, in all its forms</em>
29
+ </p>
30
+ <br/>
31
+ </div>
32
+ <br/>
33
+
34
+ ---
35
+
36
+ **Documentation**: <a href="https://icij.github.io/datashare-python" target="_blank">https://icij.github.io/datashare-python</a>
37
+
38
+ ---
39
+
40
+ # Implement **your own Datashare tasks**, written in Python
41
+
42
+ Most AI, machine learning and data engineering work happens in Python.
43
+ [Datashare](https://icij.gitbook.io/datashare) now lets you extend its backend with your own tasks implemented in Python.
44
+
45
+ Turning your own ML pipelines into Datashare tasks is **very simple**: learn how in our [documentation](https://icij.github.io/datashare-python).
46
+
47
+ Here is a quick overview of how it works.
48
+
49
+ Actually, it's *almost* as simple as cloning our [template repo](https://github.com/ICIJ/datashare-python):
50
+
51
+ ```
52
+ $ git clone git@github.com:ICIJ/datashare-python.git
53
+ ```
54
+
55
+ replacing existing [app](https://github.com/ICIJ/datashare-python/blob/main/datashare_python/app.py) tasks with your own:
56
+ ```python
57
+ from icij_worker import AsyncApp
58
+
59
+ app = AsyncApp("app")
60
+
61
+
62
+ @app.task
63
+ def hello_world() -> str:
64
+ return "Hello world"
65
+ ```
66
+
67
+ installing [`uv`](https://docs.astral.sh/uv/) to set up dependencies and running your async Datashare worker:
68
+ ```console
69
+ $ cd datashare-python
70
+ $ curl -LsSf https://astral.sh/uv/install.sh | sh
71
+ $ uv run ./scripts/worker_entrypoint.sh
72
+ [INFO][icij_worker.backend.backend]: Loading worker configuration from env...
73
+ ...
74
+ }
75
+ [INFO][icij_worker.backend.mp]: starting 1 worker for app datashare_python.app.app
76
+ ...
77
+ ```
78
+ You'll then be able to execute tasks using our HTTP client (and soon using Datashare's UI).
79
+
80
+ ## Learn more by reading our [documentation](https://icij.github.io/datashare-python)!
@@ -0,0 +1,61 @@
1
+ <div style="background-image: linear-gradient(45deg, #193d87, #fa4070);">
2
+ <br/>
3
+ <p align="center">
4
+ <a href="https://datashare.icij.org/">
5
+ <img align="center" src="docs/assets/datashare-logo.svg" alt="Datashare" style="max-width: 60%">
6
+ </a>
7
+ </p>
8
+ <p align="center">
9
+ <em>Better analyze information, in all its forms</em>
10
+ </p>
11
+ <br/>
12
+ </div>
13
+ <br/>
14
+
15
+ ---
16
+
17
+ **Documentation**: <a href="https://icij.github.io/datashare-python" target="_blank">https://icij.github.io/datashare-python</a>
18
+
19
+ ---
20
+
21
+ # Implement **your own Datashare tasks**, written in Python
22
+
23
+ Most AI, machine learning and data engineering work happens in Python.
24
+ [Datashare](https://icij.gitbook.io/datashare) now lets you extend its backend with your own tasks implemented in Python.
25
+
26
+ Turning your own ML pipelines into Datashare tasks is **very simple**: learn how in our [documentation](https://icij.github.io/datashare-python).
27
+
28
+ Here is a quick overview of how it works.
29
+
30
+ Actually, it's *almost* as simple as cloning our [template repo](https://github.com/ICIJ/datashare-python):
31
+
32
+ ```
33
+ $ git clone git@github.com:ICIJ/datashare-python.git
34
+ ```
35
+
36
+ replacing existing [app](https://github.com/ICIJ/datashare-python/blob/main/datashare_python/app.py) tasks with your own:
37
+ ```python
38
+ from icij_worker import AsyncApp
39
+
40
+ app = AsyncApp("app")
41
+
42
+
43
+ @app.task
44
+ def hello_world() -> str:
45
+ return "Hello world"
46
+ ```
47
+
48
+ installing [`uv`](https://docs.astral.sh/uv/) to set up dependencies and running your async Datashare worker:
49
+ ```console
50
+ $ cd datashare-python
51
+ $ curl -LsSf https://astral.sh/uv/install.sh | sh
52
+ $ uv run ./scripts/worker_entrypoint.sh
53
+ [INFO][icij_worker.backend.backend]: Loading worker configuration from env...
54
+ ...
55
+ }
56
+ [INFO][icij_worker.backend.mp]: starting 1 worker for app datashare_python.app.app
57
+ ...
58
+ ```
59
+ You'll then be able to execute tasks using our HTTP client (and soon using Datashare's UI).
60
+
61
+ ## Learn more by reading our [documentation](https://icij.github.io/datashare-python)!
File without changes
@@ -0,0 +1,4 @@
1
+ from datashare_python.cli import cli_app
2
+
3
# Entry point of ``python -m datashare_python``: run the Typer CLI app.
if __name__ == "__main__":
    cli_app()
@@ -0,0 +1,85 @@
1
+ from typing import Optional
2
+
3
+ from icij_worker import AsyncApp
4
+ from icij_worker.typing_ import PercentProgress
5
+ from pydantic import parse_obj_as
6
+
7
+ from datashare_python.constants import PYTHON_TASK_GROUP
8
+ from datashare_python.objects import ClassificationConfig, TranslationConfig
9
+ from datashare_python.tasks import (
10
+ classify_docs as classify_docs_,
11
+ create_classification_tasks as create_classification_tasks_,
12
+ create_translation_tasks as create_translation_tasks_,
13
+ translate_docs as translate_docs_,
14
+ )
15
+ from datashare_python.tasks.dependencies import APP_LIFESPAN_DEPS
16
+
17
# Worker application named "ml". APP_LIFESPAN_DEPS (from
# datashare_python.tasks.dependencies) is passed as the app's dependencies;
# presumably they are set up / torn down over the worker lifespan — confirm.
app = AsyncApp("ml", dependencies=APP_LIFESPAN_DEPS)
18
+
19
+
20
@app.task(group=PYTHON_TASK_GROUP)
async def create_translation_tasks(
    project: str,
    target_language: str,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> list[str]:
    """Validate the raw ``config`` and delegate to ``create_translation_tasks_``.

    ``config`` (a plain dict or None) is parsed into an optional
    ``TranslationConfig``; ``user`` is accepted but not used.
    """
    translation_config = parse_obj_as(Optional[TranslationConfig], config)
    return await create_translation_tasks_(
        config=translation_config,
        project=project,
        target_language=target_language,
    )
32
+
33
+
34
@app.task(group=PYTHON_TASK_GROUP)
async def translate_docs(
    docs: list[str],
    project: str,
    target_language: str,
    progress: PercentProgress,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> int:
    """Translate ``docs`` into ``target_language`` via ``translate_docs_``.

    The raw ``config`` is parsed into an optional ``TranslationConfig`` first;
    ``progress`` is forwarded as is and ``user`` is unused.
    """
    translation_config = parse_obj_as(Optional[TranslationConfig], config)
    return await translate_docs_(
        docs,
        target_language,
        config=translation_config,
        progress=progress,
        project=project,
    )
47
+
48
+
49
@app.task(group=PYTHON_TASK_GROUP)
async def create_classification_tasks(
    project: str,
    language: str,
    n_workers: int,
    progress: PercentProgress,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> list[str]:
    """Validate the raw ``config`` and delegate to ``create_classification_tasks_``.

    ``config`` is parsed into an optional ``ClassificationConfig``; every other
    argument except the unused ``user`` is forwarded by keyword.
    """
    classification_config = parse_obj_as(Optional[ClassificationConfig], config)
    return await create_classification_tasks_(
        config=classification_config,
        language=language,
        n_workers=n_workers,
        progress=progress,
        project=project,
    )
66
+
67
+
68
@app.task(group=PYTHON_TASK_GROUP)
async def classify_docs(
    docs: list[str],
    language: str,
    project: str,
    progress: PercentProgress,
    config: dict | None = None,
    user: dict | None = None,  # pylint: disable=unused-argument
) -> int:
    """Classify ``docs`` via ``classify_docs_``.

    The raw ``config`` is parsed into an optional ``ClassificationConfig``
    first; ``user`` is accepted but not used.
    """
    classification_config = parse_obj_as(Optional[ClassificationConfig], config)
    return await classify_docs_(
        docs,
        config=classification_config,
        language=language,
        progress=progress,
        project=project,
    )
81
+
82
+
83
@app.task(group=PYTHON_TASK_GROUP)
def ping() -> str:
    """Return the constant ``"pong"`` (presumably a worker liveness check)."""
    reply = "pong"
    return reply
@@ -0,0 +1,30 @@
1
+ import importlib.metadata
2
+ from typing import Annotated, Optional
3
+
4
+ import typer
5
+
6
+ import datashare_python
7
+ from datashare_python.cli.tasks import task_app
8
+ from datashare_python.cli.utils import AsyncTyper
9
+
10
# Root CLI application; ``-h`` is accepted as an alias of ``--help``.
cli_app = AsyncTyper(context_settings={"help_option_names": ["-h", "--help"]})
# Mount the ``task`` sub-command group (start / watch / result).
cli_app.add_typer(task_app)
12
+
13
+
14
def version_callback(value: bool):
    """Handle ``--version``: print the installed package version and exit.

    Does nothing when the flag is absent or false.
    """
    if not value:
        return
    print(importlib.metadata.version(datashare_python.__name__))
    raise typer.Exit()
19
+
20
+
21
# Root callback of the CLI; it performs no work itself. ``--version`` is eager
# and fully handled by ``version_callback``, hence the unused argument.
# NOTE: Typer presumably uses the docstring below as the CLI help text, so it
# is user-facing — change it deliberately.
@cli_app.callback(name="datashare-python")
def main(
    version: Annotated[  # pylint: disable=unused-argument
        Optional[bool],
        typer.Option(  # pylint: disable=unused-argument
            "--version", callback=version_callback, is_eager=True
        ),
    ] = None
):
    """Datashare Python CLI"""
@@ -0,0 +1,182 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import sys
5
+ from pathlib import Path
6
+ from traceback import FrameSummary, StackSummary
7
+ from typing import Annotated, Any, Optional
8
+
9
+ import typer
10
+ from alive_progress import alive_bar
11
+ from icij_worker import TaskState
12
+ from icij_worker.objects import READY_STATES, Task, TaskError
13
+
14
+ from datashare_python.cli.utils import AsyncTyper, eprint
15
+ from datashare_python.constants import PYTHON_TASK_GROUP
16
+ from datashare_python.task_client import DatashareTaskClient
17
+
18
# Module logger.
logger = logging.getLogger(__name__)

# Datashare address used when ``--ds-address`` / ``-a`` is not provided.
DEFAULT_DS_ADDRESS = "http://localhost:8080"

# Help texts of the CLI arguments, options and commands defined below.
_ARGS_HELP = "task argument as a JSON string or file path"
_GROUP_HELP = "task group"
_DS_API_KEY_HELP = "datashare API key"
_DS_URL_HELP = "datashare address"
_POLLING_INTERVAL_S_HELP = "task state polling interval in seconds"
_NAME_HELP = "registered task name"
_RESULT_HELP = "get a task result"
_START_HELP = "creates a new task and start it"
_TASK_ID_HELP = "task ID"
_WATCH_HELP = "watch a task until it's complete"

# Type of the raw ``args`` CLI argument: a JSON string or a path to a JSON file.
TaskArgs = str

# ``task`` sub-command group holding the ``start``, ``watch`` and ``result`` commands.
task_app = AsyncTyper(name="task")
36
+
37
+
38
+ @task_app.async_command(help=_START_HELP)
39
+ async def start(
40
+ name: Annotated[str, typer.Argument(help=_NAME_HELP)],
41
+ args: Annotated[TaskArgs, typer.Argument(help=_ARGS_HELP)] = None,
42
+ group: Annotated[
43
+ Optional[str],
44
+ typer.Option("--group", "-g", help=_GROUP_HELP),
45
+ ] = PYTHON_TASK_GROUP.name,
46
+ ds_address: Annotated[
47
+ str, typer.Option("--ds-address", "-a", help=_DS_URL_HELP)
48
+ ] = DEFAULT_DS_ADDRESS,
49
+ ds_api_key: Annotated[
50
+ Optional[str], typer.Option("--ds-api-key", "-k", help=_DS_API_KEY_HELP)
51
+ ] = None,
52
+ ):
53
+ match args:
54
+ case str():
55
+ as_path = Path(name)
56
+ if as_path.exists():
57
+ args = json.loads(as_path.read_text())
58
+ else:
59
+ args = json.loads(args)
60
+ case None:
61
+ args = dict()
62
+ case _:
63
+ raise TypeError(f"Invalid args {args}")
64
+ client = DatashareTaskClient(ds_address, api_key=ds_api_key)
65
+ async with client:
66
+ task_id = await client.create_task(name, args, group=group)
67
+ eprint(f"Task({task_id}) started !")
68
+ eprint(f"Task({task_id}) 🛫")
69
+ print(task_id)
70
+
71
+
72
@task_app.async_command(help=_WATCH_HELP)
async def watch(
    task_id: Annotated[str, typer.Argument(help=_TASK_ID_HELP)],
    ds_address: Annotated[
        str, typer.Option("--ds-address", "-a", help=_DS_URL_HELP)
    ] = DEFAULT_DS_ADDRESS,
    ds_api_key: Annotated[
        Optional[str], typer.Option("--ds-api-key", "-k", help=_DS_API_KEY_HELP)
    ] = None,
    polling_interval_s: Annotated[
        float, typer.Option("--polling-interval-s", "-p", help=_POLLING_INTERVAL_S_HELP)
    ] = 1.0,
):
    """Follow a task until it is ready, then print its ID on stdout.

    An already-ready task is reported immediately; otherwise its progress is
    polled every ``polling_interval_s`` seconds. Errored or cancelled tasks
    end the command with exit code 1 (``typer.Exit``).
    """
    client = DatashareTaskClient(ds_address, api_key=ds_api_key)
    async with client:
        task = await client.get_task(task_id)
        # membership (not identity) test: READY_STATES is a collection of states
        if task.state in READY_STATES:
            await _handle_ready(task, client, already_done=True)
        else:
            # only poll tasks which are not ready yet, to avoid reporting twice
            await _handle_alive(task, client, polling_interval_s)
        print(task_id)
92
+
93
+
94
@task_app.async_command(help=_RESULT_HELP)
async def result(
    task_id: Annotated[str, typer.Argument(help=_TASK_ID_HELP)],
    ds_address: Annotated[
        str, typer.Option("--ds-address", "-a", help=_DS_URL_HELP)
    ] = DEFAULT_DS_ADDRESS,
    ds_api_key: Annotated[
        Optional[str], typer.Option("--ds-api-key", "-k", help=_DS_API_KEY_HELP)
    ] = None,
) -> Any:
    """Fetch the result of task ``task_id`` and print it on stdout.

    ``dict`` and ``list`` results are rendered as indented JSON, any other
    value is printed as is.
    """
    client = DatashareTaskClient(ds_address, api_key=ds_api_key)
    async with client:
        raw = await client.get_task_result(task_id)
        is_json_container = isinstance(raw, (dict, list))
        print(json.dumps(raw, indent=2) if is_json_container else raw)
110
+
111
+
112
async def _handle_ready(
    task: Task, client: DatashareTaskClient, already_done: bool = False
) -> None:
    """Report the outcome of a task which reached a ready state.

    :param already_done: for DONE tasks, report that the task was already
        complete before watching started.
    :raises ValueError: if the task state is not ERROR, CANCELLED or DONE.
    """
    state = task.state
    if state == TaskState.ERROR:
        await _handle_error(task, client)
    elif state == TaskState.CANCELLED:
        await _handle_cancelled(task)
    elif state == TaskState.DONE:
        report = _handle_already_done if already_done else _handle_done
        await report(task)
    else:
        raise ValueError(f"Unexpected task state {task.state}")
127
+
128
+
129
async def _handle_error(task, client: DatashareTaskClient):
    """Print the task's error to stderr and exit the CLI with code 1."""
    error = await client.get_task_error(task.id)
    details = _format_error(error)
    eprint(f"Task({task.id}) failed with the following error:\n\n{details}")
    eprint(f"Task({task.id}) ❌")
    raise typer.Exit(code=1)
137
+
138
+
139
async def _handle_cancelled(task):
    """Report a cancelled task on stderr and exit the CLI with code 1."""
    label = f"Task({task.id})"
    eprint(f"{label} was cancelled !")
    eprint(f"{label} 🛑")
    raise typer.Exit(code=1)
143
+
144
+
145
async def _handle_already_done(task):
    """Tell (on stderr) that the watched task was complete before watching."""
    message = f"Task({task.id}) ✅ is already completed !"
    eprint(message)
147
+
148
+
149
async def _handle_done(task):
    """Report on stderr that the watched task just completed."""
    for marker in ("🛬", "✅"):
        eprint(f"Task({task.id}) {marker}")
152
+
153
+
154
async def _handle_alive(
    task: Task, client: DatashareTaskClient, polling_interval_s: float
) -> None:
    """Poll ``task`` until it reaches a ready state, showing progress on stderr.

    The task is re-fetched every ``polling_interval_s`` seconds and its
    progress drives a manual ``alive_bar``. Once ready, the outcome is reported
    through ``_handle_ready`` (which exits with code 1 on error/cancellation).
    """
    title = f"Task({task.id}) 🛫"
    stats = "(ETA: {eta})"
    monitor = "{percent}"
    progress_bar = alive_bar(
        title=title, manual=True, stats=stats, monitor=monitor, file=sys.stderr
    )
    with progress_bar as bar:
        while task.state not in READY_STATES:
            task = await client.get_task(task.id)
            # manual mode: set the bar to the reported progress (0.0 when unknown)
            bar(task.progress or 0.0)  # pylint: disable=not-callable
            if task.state not in READY_STATES:
                # only wait when another poll is actually needed
                await asyncio.sleep(polling_interval_s)
    await _handle_ready(task, client)
173
+
174
+
175
def _format_error(error: "TaskError") -> str:
    """Render a task error as readable, traceback-like text.

    The output is the error name, one ``File ..., line ..., in ...`` entry per
    stack frame, the error message and, when set, the error cause.
    """
    # FrameSummary takes (filename, lineno, name): the file goes first
    stack = StackSummary.from_list(
        [FrameSummary(f.file, f.lineno, f.name) for f in error.stacktrace]
    )
    # StackSummary is a list subclass: interpolating it directly yields its list
    # repr, format() renders the frames (each entry ends with a newline)
    frames = "".join(stack.format())
    msg = f"{error.name}:\n{frames}{error.message}"
    if error.cause:
        msg += f"\n caused by {error.cause}"
    return msg
@@ -0,0 +1,33 @@
1
+ import asyncio
2
+ import concurrent.futures
3
+ import sys
4
+ from functools import wraps
5
+
6
+ import typer
7
+
8
+
9
class AsyncTyper(typer.Typer):
    """``typer.Typer`` subclass able to register ``async def`` commands."""

    def async_command(self, *args, **kwargs):
        """Register a coroutine function as a CLI command.

        Takes the same arguments as ``typer.Typer.command``. The command is a
        synchronous wrapper running the coroutine with ``asyncio.run``; the
        original coroutine function is returned unchanged.
        """

        def register(coro_fn):
            # wraps() keeps the signature visible so Typer can build the CLI
            @wraps(coro_fn)
            def run_sync(*cmd_args, **cmd_kwargs):
                return asyncio.run(coro_fn(*cmd_args, **cmd_kwargs))

            self.command(*args, **kwargs)(run_sync)
            return coro_fn

        return register
21
+
22
+
23
def eprint(*args, **kwargs):
    """``print`` to standard error; takes the same arguments as ``print`` except ``file``."""
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
25
+
26
+
27
def _to_concurrent(
    fut: asyncio.Future, loop: asyncio.AbstractEventLoop
) -> concurrent.futures.Future:
    """Expose an asyncio future bound to ``loop`` as a concurrent future.

    Can be called from another thread than the loop's. The returned future
    resolves to ``fut``'s result, or raises its exception, once ``fut`` is done.
    """

    async def wait():
        # return the awaited value so the concurrent future carries fut's
        # result instead of None
        return await fut

    return asyncio.run_coroutine_threadsafe(wait(), loop)
@@ -0,0 +1,60 @@
1
+ from typing import ClassVar
2
+
3
+ from icij_common.pydantic_utils import ICIJSettings, NoEnumModel
4
+ from icij_worker.utils.logging_ import LogWithWorkerIDMixin
5
+ from pydantic import Field
6
+
7
+ import datashare_python
8
+
9
# Logger names exposed as ``AppConfig.loggers``: only the package's root logger.
_ALL_LOGGERS = [datashare_python.__name__]
10
+
11
+
12
class AppConfig(ICIJSettings, LogWithWorkerIDMixin, NoEnumModel):
    """Settings of the Python worker application.

    Environment variables use the ``DS_DOCKER_ML_`` prefix (``Config.env_prefix``),
    assuming ``ICIJSettings`` behaves like pydantic's ``BaseSettings`` — confirm.
    """

    class Config:
        env_prefix = "DS_DOCKER_ML_"

    # NOTE(review): with a ClassVar annotation pydantic does not turn this into a
    # field, so the class attribute may hold the Field(...) object itself rather
    # than the list — confirm LogWithWorkerIDMixin copes with that.
    loggers: ClassVar[list[str]] = Field(_ALL_LOGGERS, const=True)

    log_level: str = Field(default="INFO")

    batch_size: int = 1024
    pipeline_batch_size: int = 1024
    ne_buffer_size: int = 1000

    # DS
    ds_api_key: str | None = None
    ds_url: str = "http://datashare:8080"
    # ES
    es_address: str = "http://localhost:9200"
    es_default_page_size: int = 1000
    es_keep_alive: str = "10m"
    es_max_concurrency: int = 5
    es_max_retries: int = 0
    es_max_retry_wait_s: int | float = 60
    es_timeout_s: int | float = 60 * 5

    def to_es_client(self, address: str | None = None) -> "ESClient":
        """Build an ``ESClient`` from the ``es_*`` settings.

        :param address: Elasticsearch address, defaults to ``es_address``.
        :return: a client authenticated with ``ds_api_key``.
        """
        from icij_common.es import ESClient

        if address is None:
            address = self.es_address

        client = ESClient(
            hosts=[address],
            pagination=self.es_default_page_size,
            max_concurrency=self.es_max_concurrency,
            keep_alive=self.es_keep_alive,
            timeout=self.es_timeout_s,
            max_retries=self.es_max_retries,
            max_retry_wait_s=self.es_max_retry_wait_s,
            api_key=self.ds_api_key,
        )
        # Mark the transport as verified, presumably to skip the client's
        # Elasticsearch product check (e.g. when ES sits behind a proxy) — confirm.
        client.transport._verified_elasticsearch = (  # pylint: disable=protected-access
            True
        )
        return client

    def to_task_client(self) -> "DatashareTaskClient":
        """Build a task client for the Datashare server at ``ds_url``.

        The configured ``ds_api_key`` is forwarded, as the CLI commands do when
        creating the same client.
        """
        from datashare_python.task_client import DatashareTaskClient

        return DatashareTaskClient(self.ds_url, api_key=self.ds_api_key)
@@ -0,0 +1,6 @@
1
+ from pathlib import Path
2
+
3
+ from icij_worker.app import TaskGroup
4
+
5
# ``.data`` directory located next to this module.
DATA_DIR = Path(__file__).parent.joinpath(".data")
# Task group under which the app's Python tasks are registered.
PYTHON_TASK_GROUP = TaskGroup(name="PYTHON")
@@ -0,0 +1,49 @@
1
+ from typing import Self
2
+
3
+ import pycountry
4
+ from icij_common.es import DOC_CONTENT, DOC_LANGUAGE, DOC_ROOT_ID, ID_, SOURCE
5
+ from icij_common.pydantic_utils import ICIJModel, LowerCamelCaseModel
6
+ from pydantic import Field
7
+
8
+
9
class Document(LowerCamelCaseModel):
    """A document, as read from an Elasticsearch hit (see ``from_es``)."""

    id: str
    root_document: str
    content: str
    language: str
    tags: list[str] = Field(default_factory=list)
    # explicit snake_case alias, presumably to match the ES field name instead
    # of the camelCase alias LowerCamelCaseModel would generate — confirm
    content_translated: dict[str, str] = Field(
        default_factory=dict, alias="content_translated"
    )

    @classmethod
    def from_es(cls, es_doc: dict) -> Self:
        """Build a document from a raw ES hit (ID plus ``SOURCE`` fields).

        ``content_translated`` and ``tags`` default to empty when missing; the
        other source fields are required.
        """
        src = es_doc[SOURCE]
        fields = {
            "id": es_doc[ID_],
            "content": src[DOC_CONTENT],
            "content_translated": src.get("content_translated", {}),
            "language": src[DOC_LANGUAGE],
            "root_document": src[DOC_ROOT_ID],
            "tags": src.get("tags", []),
        }
        return cls(**fields)
30
+
31
+
32
class ClassificationConfig(ICIJModel):
    """Settings of the document classification tasks."""

    # Pipeline task name; ``const`` forbids any other value.
    task: str = Field(const=True, default="text-classification")
    # Model identifier (presumably a Hugging Face hub name) — confirm.
    model: str = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
    # Batch size used by the pipeline.
    batch_size: int = 16
36
+
37
+
38
class TranslationConfig(ICIJModel):
    """Settings of the document translation tasks."""

    # Pipeline task name; ``const`` forbids any other value.
    task: str = Field(const=True, default="translation")
    # Model name prefix, completed with the language pair in ``to_pipeline_args``.
    model: str = "Helsinki-NLP/opus-mt"
    batch_size: int = 16

    def to_pipeline_args(self, source_language: str, *, target_language: str) -> dict:
        """Build the pipeline arguments for a language pair.

        Languages are looked up by name in ``pycountry`` (e.g. ``"English"``)
        and mapped to ISO 639-1 codes: ``task`` becomes
        ``translation_<src>_to_<tgt>`` and ``model`` gets a ``-<src>-<tgt>``
        suffix (e.g. ``Helsinki-NLP/opus-mt-en-fr``).

        :raises ValueError: if a language is unknown or has no 2-letter code.
        """

        def to_alpha2(language: str) -> str:
            # get() returns None for unknown names, and some languages only
            # have 3-letter codes: fail with an explicit message in both cases
            found = pycountry.languages.get(name=language)
            alpha2 = getattr(found, "alpha_2", None)
            if alpha2 is None:
                raise ValueError(
                    f"Unsupported language {language!r}: no ISO 639-1 code found"
                )
            return alpha2

        as_dict = self.dict()
        source_alpha2 = to_alpha2(source_language)
        target_alpha2 = to_alpha2(target_language)
        as_dict["task"] = f"translation_{source_alpha2}_to_{target_alpha2}"
        as_dict["model"] = f"{self.model}-{source_alpha2}-{target_alpha2}"
        return as_dict