waldiez 0.3.6__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of waldiez might be problematic. Click here for more details.

Files changed (53):
  1. waldiez/__init__.py +15 -66
  2. waldiez/_version.py +1 -1
  3. waldiez/cli.py +11 -8
  4. waldiez/exporting/__init__.py +2 -0
  5. waldiez/exporting/agent/agent_exporter.py +11 -2
  6. waldiez/exporting/agent/utils/__init__.py +2 -0
  7. waldiez/exporting/agent/utils/agent_class_name.py +2 -0
  8. waldiez/exporting/agent/utils/agent_imports.py +5 -0
  9. waldiez/exporting/agent/utils/reasoning.py +36 -0
  10. waldiez/exporting/flow/flow_exporter.py +21 -8
  11. waldiez/exporting/flow/utils/__init__.py +10 -5
  12. waldiez/exporting/flow/utils/def_main.py +25 -20
  13. waldiez/exporting/flow/utils/flow_content.py +42 -1
  14. waldiez/exporting/flow/utils/importing_utils.py +7 -1
  15. waldiez/exporting/flow/utils/logging_utils.py +176 -42
  16. waldiez/models/__init__.py +8 -0
  17. waldiez/models/agents/__init__.py +10 -0
  18. waldiez/models/agents/agent/agent.py +10 -4
  19. waldiez/models/agents/agent/termination_message.py +2 -0
  20. waldiez/models/agents/agents.py +10 -0
  21. waldiez/models/agents/rag_user/retrieve_config.py +46 -17
  22. waldiez/models/agents/reasoning/__init__.py +13 -0
  23. waldiez/models/agents/reasoning/reasoning_agent.py +43 -0
  24. waldiez/models/agents/reasoning/reasoning_agent_data.py +116 -0
  25. waldiez/models/agents/reasoning/reasoning_agent_reason_config.py +101 -0
  26. waldiez/models/agents/swarm_agent/__init__.py +2 -1
  27. waldiez/models/agents/swarm_agent/swarm_agent_data.py +2 -3
  28. waldiez/models/chat/chat_data.py +30 -63
  29. waldiez/models/chat/chat_message.py +2 -26
  30. waldiez/models/chat/chat_nested.py +7 -8
  31. waldiez/models/common/__init__.py +3 -18
  32. waldiez/models/common/date_utils.py +18 -0
  33. waldiez/models/common/dict_utils.py +37 -0
  34. waldiez/models/common/method_utils.py +2 -5
  35. waldiez/models/flow/flow_data.py +1 -1
  36. waldiez/models/waldiez.py +4 -1
  37. waldiez/runner.py +3 -3
  38. waldiez/running/environment.py +22 -16
  39. waldiez/running/gen_seq_diagram.py +7 -4
  40. waldiez/running/running.py +67 -19
  41. waldiez/utils/__init__.py +15 -0
  42. waldiez/utils/cli_extras/__init__.py +30 -0
  43. waldiez/{cli_extras.py → utils/cli_extras/jupyter.py} +9 -20
  44. waldiez/utils/cli_extras/studio.py +36 -0
  45. waldiez/{conflict_checker.py → utils/conflict_checker.py} +14 -3
  46. waldiez/utils/flaml_warnings.py +17 -0
  47. waldiez/utils/pysqlite3_checker.py +249 -0
  48. {waldiez-0.3.6.dist-info → waldiez-0.3.7.dist-info}/METADATA +27 -19
  49. {waldiez-0.3.6.dist-info → waldiez-0.3.7.dist-info}/RECORD +53 -40
  50. waldiez-0.3.7.dist-info/licenses/NOTICE.md +5 -0
  51. {waldiez-0.3.6.dist-info → waldiez-0.3.7.dist-info}/WHEEL +0 -0
  52. {waldiez-0.3.6.dist-info → waldiez-0.3.7.dist-info}/entry_points.txt +0 -0
  53. {waldiez-0.3.6.dist-info → waldiez-0.3.7.dist-info}/licenses/LICENSE +0 -0
@@ -26,24 +26,30 @@ def in_virtualenv() -> bool:
26
26
 
27
27
  def refresh_environment() -> None:
28
28
  """Refresh the environment."""
29
- # backup the default IOStream
30
- from autogen.io import IOStream # type: ignore
29
+ with warnings.catch_warnings():
30
+ warnings.filterwarnings(
31
+ "ignore",
32
+ module="flaml",
33
+ message="^.*flaml.automl is not available.*$",
34
+ )
35
+ from autogen.io import IOStream # type: ignore
31
36
 
32
- default_io_stream = IOStream.get_default()
33
- site.main()
34
- # pylint: disable=import-outside-toplevel
35
- modules_to_reload = [mod for mod in sys.modules if "autogen" in mod]
36
- for mod in modules_to_reload:
37
- del sys.modules[mod]
38
- warnings.filterwarnings(
39
- "ignore", module="flaml", message="^.*flaml.automl is not available.*$"
40
- )
41
- import autogen # type: ignore
42
- from autogen.io import IOStream
37
+ default_io_stream = IOStream.get_default()
38
+ site.main()
39
+ # pylint: disable=import-outside-toplevel
40
+ modules_to_reload = [mod for mod in sys.modules if "autogen" in mod]
41
+ for mod in modules_to_reload:
42
+ del sys.modules[mod]
43
+ import autogen # type: ignore
44
+ from autogen.io import IOStream
43
45
 
44
- importlib.reload(autogen)
45
- # restore the default IOStream
46
- IOStream.set_global_default(default_io_stream)
46
+ importlib.reload(autogen)
47
+ # restore the default IOStream
48
+ IOStream.set_global_default(default_io_stream)
49
+ # reload any other modules that may have been affected
50
+ for mod in modules_to_reload:
51
+ if mod not in sys.modules:
52
+ importlib.import_module(mod)
47
53
 
48
54
 
49
55
  def set_env_vars(flow_env_vars: List[Tuple[str, str]]) -> Dict[str, str]:
@@ -173,10 +173,13 @@ def generate_sequence_diagram(
173
173
  if file_path.suffix not in [".json", ".csv"]:
174
174
  raise ValueError("Input file must be a JSON or CSV file.")
175
175
  is_csv = file_path.suffix == ".csv"
176
- if is_csv:
177
- df_events = pd.read_csv(file_path)
178
- else:
179
- df_events = pd.read_json(file_path)
176
+ try:
177
+ if is_csv:
178
+ df_events = pd.read_csv(file_path)
179
+ else:
180
+ df_events = pd.read_json(file_path)
181
+ except pd.errors.EmptyDataError:
182
+ return
180
183
 
181
184
  # Generate the Mermaid sequence diagram text
182
185
  mermaid_text = process_events(df_events)
@@ -10,6 +10,7 @@ import shutil
10
10
  import subprocess
11
11
  import sys
12
12
  import tempfile
13
+ import warnings
13
14
  from contextlib import asynccontextmanager, contextmanager
14
15
  from pathlib import Path
15
16
  from typing import (
@@ -157,7 +158,7 @@ async def a_install_requirements(
157
158
  printer(f"Installing requirements: {requirements_string}")
158
159
  pip_install = [sys.executable, "-m", "pip", "install"]
159
160
  if not in_virtualenv():
160
- pip_install.append("--user")
161
+ pip_install.extend(["--user", "--break-system-packages"])
161
162
  pip_install.extend(extra_requirements)
162
163
  proc = await asyncio.create_subprocess_exec(
163
164
  *pip_install,
@@ -198,15 +199,14 @@ def after_run(
198
199
  if isinstance(output_path, str):
199
200
  output_path = Path(output_path)
200
201
  output_dir = output_path.parent if output_path else Path.cwd()
201
- if output_dir.is_file():
202
- output_dir = output_dir.parent
203
202
  if skip_mmd is False:
204
203
  events_csv_path = temp_dir / "logs" / "events.csv"
205
204
  if events_csv_path.exists():
206
205
  printer("Generating mermaid sequence diagram...")
207
206
  mmd_path = temp_dir / f"{flow_name}.mmd"
208
207
  generate_sequence_diagram(events_csv_path, mmd_path)
209
- shutil.copyfile(mmd_path, output_dir / f"{flow_name}.mmd")
208
+ if mmd_path.exists():
209
+ shutil.copyfile(mmd_path, output_dir / f"{flow_name}.mmd")
210
210
  if output_path:
211
211
  destination_dir = output_path.parent
212
212
  destination_dir = (
@@ -217,23 +217,65 @@ def after_run(
217
217
  destination_dir.mkdir(parents=True, exist_ok=True)
218
218
  # copy the contents of the temp dir to the destination dir
219
219
  printer(f"Copying the results to {destination_dir}")
220
- for item in temp_dir.iterdir():
221
- # skip cache files
222
- if (
223
- item.name.startswith("__pycache__")
224
- or item.name.endswith(".pyc")
225
- or item.name.endswith(".pyo")
226
- or item.name.endswith(".pyd")
227
- or item.name == ".cache"
228
- ):
229
- continue
230
- if item.is_file():
231
- shutil.copy(item, destination_dir)
232
- else:
233
- shutil.copytree(item, destination_dir / item.name)
220
+ copy_results(
221
+ temp_dir=temp_dir,
222
+ output_path=output_path,
223
+ output_dir=output_dir,
224
+ destination_dir=destination_dir,
225
+ )
234
226
  shutil.rmtree(temp_dir)
235
227
 
236
228
 
229
+ def copy_results(
230
+ temp_dir: Path,
231
+ output_path: Path,
232
+ output_dir: Path,
233
+ destination_dir: Path,
234
+ ) -> None:
235
+ """Copy the results to the output directory.
236
+
237
+ Parameters
238
+ ----------
239
+ temp_dir : Path
240
+ The temporary directory.
241
+ output_path : Path
242
+ The output path.
243
+ output_dir : Path
244
+ The output directory.
245
+ destination_dir : Path
246
+ The destination directory.
247
+ """
248
+ temp_dir.mkdir(parents=True, exist_ok=True)
249
+ for item in temp_dir.iterdir():
250
+ # skip cache files
251
+ if (
252
+ item.name.startswith("__pycache__")
253
+ or item.name.endswith(".pyc")
254
+ or item.name.endswith(".pyo")
255
+ or item.name.endswith(".pyd")
256
+ or item.name == ".cache"
257
+ ):
258
+ continue
259
+ if item.is_file():
260
+ # let's also copy the tree of thoughts image
261
+ # to the output directory
262
+ if item.name.endswith("tree_of_thoughts.png"):
263
+ shutil.copy(item, output_dir / item.name)
264
+ shutil.copy(item, destination_dir)
265
+ else:
266
+ shutil.copytree(item, destination_dir / item.name)
267
+ if output_path.is_file():
268
+ if output_path.suffix == ".waldiez":
269
+ output_path = output_path.with_suffix(".py")
270
+ if output_path.suffix == ".py":
271
+ src = temp_dir / output_path.name
272
+ if src.exists():
273
+ dst = destination_dir / output_path.name
274
+ if dst.exists():
275
+ dst.unlink()
276
+ shutil.copyfile(src, output_dir / output_path.name)
277
+
278
+
237
279
  def get_printer() -> Callable[..., None]:
238
280
  """Get the printer function.
239
281
 
@@ -242,7 +284,13 @@ def get_printer() -> Callable[..., None]:
242
284
  Callable[..., None]
243
285
  The printer function.
244
286
  """
245
- from autogen.io import IOStream # type: ignore
287
+ with warnings.catch_warnings():
288
+ warnings.filterwarnings(
289
+ "ignore",
290
+ module="flaml",
291
+ message="^.*flaml.automl is not available.*$",
292
+ )
293
+ from autogen.io import IOStream # type: ignore
246
294
 
247
295
  printer = IOStream.get_default().print
248
296
 
@@ -0,0 +1,15 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Utils to call on init."""
4
+
5
+ from .cli_extras import add_cli_extras
6
+ from .conflict_checker import check_conflicts
7
+ from .flaml_warnings import check_flaml_warnings
8
+ from .pysqlite3_checker import check_pysqlite3
9
+
10
+ __all__ = [
11
+ "check_conflicts",
12
+ "check_flaml_warnings",
13
+ "add_cli_extras",
14
+ "check_pysqlite3",
15
+ ]
@@ -0,0 +1,30 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ # pylint: skip-file
4
+ # isort: skip_file
5
+ """Extra typer commands for CLI."""
6
+
7
+ import typer
8
+
9
+ from .jupyter import add_jupyter_cli
10
+ from .studio import add_studio_cli
11
+
12
+
13
+ def add_cli_extras(app: typer.Typer) -> None:
14
+ """Add extra CLI commands to the app.
15
+
16
+ Parameters
17
+ ----------
18
+ app : typer.Typer
19
+ The Typer app to add the extra commands to.
20
+
21
+ Returns
22
+ -------
23
+ typer.Typer
24
+ The app with the extra commands added
25
+ """
26
+ add_jupyter_cli(app)
27
+ add_studio_cli(app)
28
+
29
+
30
+ __all__ = ["add_cli_extras"]
@@ -1,9 +1,9 @@
1
1
  # SPDX-License-Identifier: Apache-2.0.
2
2
  # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
3
  # pylint: skip-file
4
- # type: ignore
5
4
  # isort: skip_file
6
- """Extra typer commands for CLI."""
5
+ # flake8: noqa: E501
6
+ """Waldiez-jupyter extra typer commands for CLI."""
7
7
 
8
8
  from typing import Callable
9
9
 
@@ -11,40 +11,29 @@ import typer
11
11
  from typer.models import CommandInfo
12
12
  import subprocess # nosemgrep # nosec
13
13
 
14
- HAVE_STUDIO = False
15
14
  HAVE_JUPYTER = False
16
- try:
17
- from waldiez_studio.cli import run as studio_app
18
-
19
- HAVE_STUDIO = True
20
- except BaseException:
21
- pass
22
15
 
23
16
  try:
24
- import waldiez_jupyter # noqa: F401
17
+ import waldiez_jupyter # type: ignore[unused-ignore, unused-import, import-not-found, import-untyped] # noqa
25
18
 
26
19
  HAVE_JUPYTER = True
27
20
  except BaseException:
28
21
  pass
29
22
 
30
23
 
31
- def add_cli_extras(app: typer.Typer) -> None:
32
- """Add extra CLI commands to the app.
24
+ def add_jupyter_cli(app: typer.Typer) -> None:
25
+ """Add Jupyter extra command to the app if available.
33
26
 
34
27
  Parameters
35
28
  ----------
36
29
  app : typer.Typer
37
- The Typer app to add the extra commands to.
30
+ The Typer app to add the extra command to.
38
31
 
39
32
  Returns
40
33
  -------
41
34
  typer.Typer
42
- The app with the extra commands added
35
+ The app with the extra command added
43
36
  """
44
- if HAVE_STUDIO:
45
- app.registered_commands.append(
46
- CommandInfo(name="studio", callback=studio_app)
47
- )
48
37
  if HAVE_JUPYTER:
49
38
  jupyter_app = get_jupyter_app()
50
39
  app.registered_commands.append(
@@ -112,9 +101,9 @@ def get_jupyter_app() -> Callable[..., None]:
112
101
  if not browser:
113
102
  command.append("--no-browser")
114
103
  if password:
115
- from jupyter_server.auth import passwd
104
+ from jupyter_server.auth import passwd # type: ignore[unused-ignore, import-not-found, attr-defined] # noqa
116
105
 
117
- hashed_password = passwd(password)
106
+ hashed_password = passwd(password) # type: ignore[unused-ignore, no-untyped-call] # noqa
118
107
  command.append(f"--ServerApp.password={hashed_password}")
119
108
  subprocess.run(command)
120
109
 
@@ -0,0 +1,36 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ # pylint: skip-file
4
+ # isort: skip_file
5
+ """Waldiez-studio extra typer commands for CLI."""
6
+
7
+ from typing import Any, Callable
8
+
9
+ import typer
10
+ from typer.models import CommandInfo
11
+
12
+ HAVE_STUDIO = False
13
+ studio_app: Callable[..., Any] | None = None
14
+
15
+ try:
16
+ from waldiez_studio.cli import run # type: ignore[unused-ignore, import-untyped, import-not-found] # noqa
17
+
18
+ studio_app = run
19
+
20
+ HAVE_STUDIO = True
21
+ except BaseException:
22
+ pass
23
+
24
+
25
+ def add_studio_cli(app: typer.Typer) -> None:
26
+ """Add studio command to the app if available.
27
+
28
+ Parameters
29
+ ----------
30
+ app : typer.Typer
31
+ The Typer app to add the studio command to.
32
+ """
33
+ if HAVE_STUDIO:
34
+ app.registered_commands.append(
35
+ CommandInfo(name="studio", callback=studio_app)
36
+ )
@@ -6,9 +6,11 @@
6
6
  import sys
7
7
  from importlib.metadata import PackageNotFoundError, version
8
8
 
9
+ __WALDIEZ_CHECKED_FOR_CONFLICTS = False
10
+
9
11
 
10
12
  # fmt: off
11
- def check_conflicts() -> None: # pragma: no cover
13
+ def _check_conflicts() -> None: # pragma: no cover
12
14
  """Check for conflicts with 'autogen-agentchat' package."""
13
15
  try:
14
16
  version("autogen-agentchat")
@@ -19,10 +21,19 @@ def check_conflicts() -> None: # pragma: no cover
19
21
  "Please uninstall 'autogen-agentchat': \n"
20
22
  f"{sys.executable} -m pip uninstall -y autogen-agentchat" + "\n"
21
23
  "And install 'pyautogen' (and/or 'waldiez') again: \n"
22
- f"{sys.executable} -m pip install --force pyautogen waldiez"
24
+ f"{sys.executable} -m pip install --force pyautogen waldiez",
25
+ file=sys.stderr,
23
26
  )
24
27
  sys.exit(1)
25
28
  except PackageNotFoundError:
26
29
  pass
27
-
28
30
  # fmt: on
31
+
32
+
33
+ def check_conflicts() -> None: # pragma: no cover
34
+ """Check for conflicts with 'autogen-agentchat' package."""
35
+ # pylint: disable=global-statement
36
+ global __WALDIEZ_CHECKED_FOR_CONFLICTS
37
+ if __WALDIEZ_CHECKED_FOR_CONFLICTS is False:
38
+ _check_conflicts()
39
+ __WALDIEZ_CHECKED_FOR_CONFLICTS = True
@@ -0,0 +1,17 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ """Try to suppress the annoying flaml.automl not being available warning."""
4
+
5
+ import logging
6
+
7
+ __WALDIEZ_CHECKED_FLAML_WARNINGS = False
8
+
9
+
10
+ def check_flaml_warnings() -> None: # pragma: no cover
11
+ """Check for flaml warnings once."""
12
+ # pylint: disable=global-statement
13
+ global __WALDIEZ_CHECKED_FLAML_WARNINGS
14
+ if __WALDIEZ_CHECKED_FLAML_WARNINGS is False:
15
+ flam_logger = logging.getLogger("flaml")
16
+ flam_logger.setLevel(logging.ERROR)
17
+ __WALDIEZ_CHECKED_FLAML_WARNINGS = True
@@ -0,0 +1,249 @@
1
+ # SPDX-License-Identifier: Apache-2.0.
2
+ # Copyright (c) 2024 - 2025 Waldiez and contributors.
3
+ # flake8: noqa: E501
4
+ """Try to install pysqlite3-binary.
5
+
6
+ Highly recommended to be run in a virtual environment.
7
+ 'setuptools' and 'wheel' will also be installed if not already installed.
8
+ """
9
+
10
+ import os
11
+ import shutil
12
+ import site
13
+ import subprocess
14
+ import sys
15
+ import tempfile
16
+ import urllib.request
17
+ import zipfile
18
+
19
+ PYSQLITE3_VERSION = "0.5.4"
20
+ SQLITE_URL = "https://www.sqlite.org/2025/sqlite-amalgamation-3480000.zip"
21
+ PYSQLITE3_URL = f"https://github.com/coleifer/pysqlite3/archive/refs/tags/{PYSQLITE3_VERSION}.zip" # pylint: disable=line-too-long
22
+
23
+ PIP = f"{sys.executable} -m pip"
24
+
25
+
26
+ def run_command(command: str, cwd: str = ".") -> None:
27
+ """Run a command.
28
+
29
+ Parameters
30
+ ----------
31
+ command : str
32
+ The command to run.
33
+ cwd : str
34
+ The current working directory.
35
+ """
36
+ cmd_parts = command.split(" ")
37
+ if cwd == ".":
38
+ cwd = os.getcwd()
39
+ try:
40
+ subprocess.run( # nosemgrep # nosec
41
+ cmd_parts,
42
+ check=True,
43
+ cwd=cwd,
44
+ env=os.environ,
45
+ encoding="utf-8",
46
+ )
47
+ except subprocess.CalledProcessError as e:
48
+ print(f"Error running command: {e}")
49
+ sys.exit(1)
50
+
51
+
52
+ def in_virtualenv() -> bool:
53
+ """Check if the script is running in a virtual environment.
54
+
55
+ Returns
56
+ -------
57
+ bool
58
+ True if in a virtual environment, False otherwise.
59
+ """
60
+ return hasattr(sys, "real_prefix") or (
61
+ hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix
62
+ )
63
+
64
+
65
+ def pip_install(*package_names: str, cwd: str = ".") -> None:
66
+ """Install packages using pip.
67
+
68
+ Parameters
69
+ ----------
70
+ *package_names : tuple[str, ...]
71
+ The package names or paths to install.
72
+ cwd : str
73
+ The current working directory.
74
+ """
75
+ args = "-qq"
76
+ if not in_virtualenv():
77
+ args += " --user --break-system-packages"
78
+ package_names_str = " ".join(package_names)
79
+ run_command(f"{PIP} install {args} {package_names_str}", cwd)
80
+
81
+
82
+ def pip_uninstall(*package_names: str, cwd: str = ".") -> None:
83
+ """Uninstall packages using pip.
84
+
85
+ Parameters
86
+ ----------
87
+ *package_names : tuple[str, ...]
88
+ The package names to uninstall.
89
+ cwd : str
90
+ The current working directory.
91
+ """
92
+ args = "-qq --yes"
93
+ if not in_virtualenv():
94
+ args += " --break-system-packages"
95
+ package_names_str = " ".join(package_names)
96
+ run_command(f"{PIP} uninstall {args} {package_names_str}", cwd)
97
+
98
+
99
+ def download_sqlite_amalgamation() -> str:
100
+ """Download the SQLite amalgamation source code.
101
+
102
+ Returns
103
+ -------
104
+ str
105
+ The path to the extracted SQLite source code.
106
+ """
107
+ zip_path = "sqlite_amalgamation.zip"
108
+ extract_path = "sqlite_amalgamation"
109
+
110
+ # Download the SQLite source code
111
+ print("Downloading SQLite amalgamation source code...")
112
+ urllib.request.urlretrieve(SQLITE_URL, zip_path) # nosec
113
+
114
+ # Extract the SQLite source code
115
+ print("Extracting SQLite source code...")
116
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
117
+ zip_ref.extractall(extract_path)
118
+
119
+ # Clean up the zip file
120
+ os.remove(zip_path)
121
+
122
+ # Return the path to the extracted source code
123
+ return os.path.join(extract_path, "sqlite-amalgamation-3480000")
124
+
125
+
126
+ def rename_package_name(pysqlite3_dir: str) -> None:
127
+ """Rename the package name in the setup.py file.
128
+
129
+ Parameters
130
+ ----------
131
+ pysqlite3_dir : str
132
+ The path to the pysqlite3 directory.
133
+ """
134
+ setup_file = os.path.join(pysqlite3_dir, "setup.py")
135
+ with open(setup_file, "r", encoding="utf-8") as file:
136
+ setup_py = file.read()
137
+ # sed -i "s|name='pysqlite3-binary'|name=PACKAGE_NAME|g" setup.py
138
+ setup_py = setup_py.replace(
139
+ "name='pysqlite3'", "name='pysqlite3-binary'"
140
+ ).replace("name=PACKAGE_NAME,", "name='pysqlite3-binary',")
141
+ with open(setup_file, "w", encoding="utf-8", newline="\n") as file:
142
+ file.write(setup_py)
143
+
144
+
145
+ def prepare_pysqlite3(sqlite_amalgamation_path: str) -> str:
146
+ """Prepare pysqlite3 using the SQLite amalgamation source code.
147
+
148
+ Parameters
149
+ ----------
150
+ sqlite_amalgamation_path : str
151
+ The path to the SQLite amalgamation source code.
152
+
153
+ Returns
154
+ -------
155
+ str
156
+ The path to the pysqlite3 directory.
157
+ """
158
+ pysqlite3_zip = "pysqlite3.zip"
159
+ pysqlite3_extract = "pysqlite3"
160
+ urllib.request.urlretrieve(PYSQLITE3_URL, pysqlite3_zip) # nosec
161
+ with zipfile.ZipFile(pysqlite3_zip, "r") as zip_ref:
162
+ zip_ref.extractall(pysqlite3_extract)
163
+ os.remove(pysqlite3_zip)
164
+ sqlite3_c = os.path.join(sqlite_amalgamation_path, "sqlite3.c")
165
+ sqlite3_h = os.path.join(sqlite_amalgamation_path, "sqlite3.h")
166
+ pysqlite3_dir = os.path.join(
167
+ pysqlite3_extract, f"pysqlite3-{PYSQLITE3_VERSION}"
168
+ )
169
+ shutil.copy(sqlite3_c, pysqlite3_dir)
170
+ shutil.copy(sqlite3_h, pysqlite3_dir)
171
+ rename_package_name(pysqlite3_dir)
172
+ return pysqlite3_dir
173
+
174
+
175
+ def install_pysqlite3(sqlite_amalgamation_path: str) -> None:
176
+ """Install pysqlite3 using the SQLite amalgamation source code.
177
+
178
+ Parameters
179
+ ----------
180
+ sqlite_amalgamation_path : str
181
+ The path to the SQLite amalgamation source code.
182
+ """
183
+ # pylint: disable=too-many-try-statements
184
+ try:
185
+ pysqlite3_dir = prepare_pysqlite3(sqlite_amalgamation_path)
186
+ pip_install("setuptools")
187
+ run_command(f"{sys.executable} setup.py build_static", pysqlite3_dir)
188
+ pip_install("wheel")
189
+ run_command(
190
+ f"{PIP} wheel . -w dist",
191
+ pysqlite3_dir,
192
+ )
193
+ wheel_file = os.listdir(os.path.join(pysqlite3_dir, "dist"))[0]
194
+ wheel_path = os.path.join("dist", wheel_file)
195
+ pip_install(wheel_path, cwd=pysqlite3_dir)
196
+ except BaseException as e: # pylint: disable=broad-except
197
+ print(f"Failed to install pysqlite3: {e}")
198
+ sys.exit(1)
199
+
200
+
201
+ def check_pysqlite3() -> None:
202
+ """Check the installation of pysqlite3."""
203
+ # pylint: disable=unused-import, import-outside-toplevel, line-too-long
204
+ try:
205
+ import pysqlite3 # type: ignore[unused-ignore, import-untyped, import-not-found] # noqa
206
+ except ImportError:
207
+ print("pysqlite3 not found or cannot be imported.")
208
+ # Uninstall pysqlite3-binary if it is already installed
209
+ pip_uninstall("pysqlite3", "pysqlite3-binary")
210
+ source_path = download_sqlite_amalgamation()
211
+ install_pysqlite3(source_path)
212
+ site.main()
213
+ # Re-import pysqlite3 as sqlite3
214
+ import pysqlite3 # type: ignore[unused-ignore, import-untyped, import-not-found] # noqa
215
+
216
+
217
+ def test_sqlite_usage() -> None:
218
+ """Test the usage of the sqlite3 module."""
219
+ # pylint: disable=import-outside-toplevel
220
+ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
221
+ import sqlite3 # noqa
222
+
223
+ print(sqlite3.__file__)
224
+ # it should be sth like: /path/to/site-packages/pysqlite3/__init__.py
225
+ conn = sqlite3.connect(":memory:")
226
+ cursor = conn.cursor()
227
+ cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
228
+ cursor.execute("INSERT INTO test (name) VALUES ('test')")
229
+ cursor.execute("SELECT * FROM test")
230
+ rows = cursor.fetchall()
231
+ print(rows)
232
+ conn.close()
233
+
234
+
235
+ def main() -> None:
236
+ """Run the check."""
237
+ if "--force" in sys.argv:
238
+ pip_uninstall("pysqlite3", "pysqlite3-binary")
239
+ cwd = os.getcwd()
240
+ tmpdir = tempfile.mkdtemp()
241
+ os.chdir(tmpdir)
242
+ check_pysqlite3()
243
+ os.chdir(cwd)
244
+ shutil.rmtree(tmpdir)
245
+ test_sqlite_usage()
246
+
247
+
248
+ if __name__ == "__main__":
249
+ main()