outlines-haystack 0.0.1a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,92 @@
1
+ name: 📢 Publish to PyPI and TestPyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "[0-9]+.[0-9]+.[0-9]+*"
7
+
8
+ jobs:
9
+ build:
10
+ name: 📦 Build distribution
11
+ runs-on: ubuntu-latest
12
+
13
+ steps:
14
+ - uses: actions/checkout@v4
15
+ - name: Set up Python
16
+ uses: actions/setup-python@v5
17
+ with:
18
+ python-version: "3.9"
19
+
20
+ - name: Install Hatch
21
+ run: pip install --upgrade hatch
22
+
23
+ - name: Build a binary wheel and a source tarball
24
+ run: hatch build
25
+
26
+ - name: Store the distribution packages
27
+ uses: actions/upload-artifact@v4
28
+ with:
29
+ name: python-package-distributions
30
+ path: dist/
31
+
32
+
33
+ publish-to-pypi:
34
+ name: 🚀 Publish to PyPI
35
+ needs:
36
+ - build
37
+ runs-on: ubuntu-latest
38
+ environment:
39
+ name: pypi
40
+ url: https://pypi.org/p/outlines-haystack/
41
+ permissions:
42
+ id-token: write # IMPORTANT: mandatory for trusted publishing
43
+
44
+ steps:
45
+ - name: Download `dist` from the artifacts
46
+ uses: actions/download-artifact@v4
47
+ with:
48
+ name: python-package-distributions
49
+ path: dist/
50
+ - name: Publish the distribution to PyPI
51
+ uses: pypa/gh-action-pypi-publish@release/v1
52
+
53
+
54
+ github-release:
55
+ name: 📝 Sign the distribution with Sigstore and upload to GitHub Release
56
+ needs:
57
+ - publish-to-pypi
58
+ runs-on: ubuntu-latest
59
+ permissions:
60
+ contents: write # IMPORTANT: mandatory for making GitHub Releases
61
+ id-token: write # IMPORTANT: mandatory for sigstore
62
+
63
+ steps:
64
+ - name: Download `dist` from the artifacts
65
+ uses: actions/download-artifact@v4
66
+ with:
67
+ name: python-package-distributions
68
+ path: dist/
69
+ - name: Sign the dists with Sigstore
70
+ uses: sigstore/gh-action-sigstore-python@v2.1.1
71
+ with:
72
+ inputs: >-
73
+ ./dist/*.tar.gz
74
+ ./dist/*.whl
75
+ - name: Create GitHub Release
76
+ env:
77
+ GITHUB_TOKEN: ${{ github.token }}
78
+ run: >-
79
+ gh release create
80
+ '${{ github.ref_name }}'
81
+ --repo '${{ github.repository }}'
82
+ --notes ""
83
+ - name: Upload artifact signatures to GitHub Release
84
+ env:
85
+ GITHUB_TOKEN: ${{ github.token }}
86
+ # Upload to GitHub Release using the `gh` CLI.
87
+ # `dist/` contains the built packages, and the
88
+ # sigstore-produced signatures and certificates.
89
+ run: >-
90
+ gh release upload
91
+ '${{ github.ref_name }}' dist/**
92
+ --repo '${{ github.repository }}'
@@ -0,0 +1,50 @@
1
+ name: 🔎 Run Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ concurrency:
10
+ group: test-${{ github.head_ref }}
11
+ cancel-in-progress: true
12
+
13
+ env:
14
+ PYTHONUNBUFFERED: "1"
15
+ FORCE_COLOR: "1"
16
+ UV_SYSTEM_PYTHON: "true"
17
+
18
+ jobs:
19
+ run:
20
+ name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }}
21
+ runs-on: ${{ matrix.os }}
22
+ strategy:
23
+ fail-fast: false
24
+ matrix:
25
+ os: [ubuntu-latest, windows-latest, macos-latest]
26
+ python-version: ['3.9', '3.10', '3.11', '3.12']
27
+
28
+ steps:
29
+ - uses: actions/checkout@v4
30
+
31
+ - name: Set up Python ${{ matrix.python-version }}
32
+ uses: actions/setup-python@v5
33
+ with:
34
+ python-version: ${{ matrix.python-version }}
35
+
36
+ - name: Install the latest version of uv
37
+ uses: astral-sh/setup-uv@v3
38
+ with:
39
+ version: "latest"
40
+
41
+ - name: Install Hatch
42
+ run: uv tool install --upgrade hatch
43
+ shell: bash
44
+
45
+ # - name: Run mypy types checking
46
+ # run: hatch run types:check
47
+
48
+ - name: Run unit tests
49
+ run: hatch run test-cov-all
50
+ shell: bash
@@ -0,0 +1,166 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # VSCode
163
+ .vscode/
164
+
165
+ # MacOS
166
+ .DS_Store
@@ -0,0 +1,31 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v5.0.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: end-of-file-fixer
7
+ - id: check-toml
8
+ - id: check-yaml
9
+ - id: check-json
10
+ - id: check-merge-conflict
11
+ args: ["--assume-in-merge"]
12
+ - id: check-added-large-files
13
+ args: ["--maxkb=1024"]
14
+ - id: debug-statements
15
+ - id: detect-private-key
16
+
17
+ - repo: https://github.com/psf/black
18
+ rev: 24.10.0
19
+ hooks:
20
+ - id: black
21
+
22
+ - repo: https://github.com/astral-sh/ruff-pre-commit
23
+ rev: v0.8.0
24
+ hooks:
25
+ - id: ruff
26
+ args: ["--fix"]
27
+
28
+ - repo: https://github.com/nbQA-dev/nbQA
29
+ rev: 1.9.1
30
+ hooks:
31
+ - id: nbqa-black
@@ -0,0 +1,9 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024-present Edoardo Abati
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -0,0 +1,61 @@
1
+ Metadata-Version: 2.3
2
+ Name: outlines-haystack
3
+ Version: 0.0.1a1
4
+ Summary: Haystack integration with outlines.
5
+ Project-URL: Documentation, https://github.com/EdAbati/outlines-haystack#readme
6
+ Project-URL: Issues, https://github.com/EdAbati/outlines-haystack/issues
7
+ Project-URL: Source, https://github.com/EdAbati/outlines-haystack
8
+ Author: Edoardo Abati
9
+ License: MIT License
10
+
11
+ Copyright (c) 2024-present Edoardo Abati
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18
+ Keywords: ai,generative-ai,haystack,llm,machine-learning,nlp,outlines,structured-generation
19
+ Classifier: Development Status :: 4 - Beta
20
+ Classifier: Programming Language :: Python
21
+ Classifier: Programming Language :: Python :: 3.9
22
+ Classifier: Programming Language :: Python :: 3.10
23
+ Classifier: Programming Language :: Python :: 3.11
24
+ Classifier: Programming Language :: Python :: 3.12
25
+ Classifier: Programming Language :: Python :: Implementation :: CPython
26
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
27
+ Requires-Python: >=3.9
28
+ Requires-Dist: haystack-ai>=2.5.0
29
+ Requires-Dist: outlines>=0.1.0
30
+ Provides-Extra: mlxlm
31
+ Requires-Dist: mlx; extra == 'mlxlm'
32
+ Requires-Dist: mlx-lm<0.19; extra == 'mlxlm'
33
+ Provides-Extra: openai
34
+ Requires-Dist: openai; extra == 'openai'
35
+ Provides-Extra: transformers
36
+ Requires-Dist: datasets; extra == 'transformers'
37
+ Requires-Dist: torch; extra == 'transformers'
38
+ Requires-Dist: transformers; extra == 'transformers'
39
+ Description-Content-Type: text/markdown
40
+
41
+ # `outlines-haystack`
42
+
43
+ [![PyPI - Version](https://img.shields.io/pypi/v/outlines-haystack.svg)](https://pypi.org/project/outlines-haystack)
44
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/outlines-haystack.svg)](https://pypi.org/project/outlines-haystack)
45
+
46
+ -----
47
+
48
+ ## Table of Contents
49
+
50
+ - [Installation](#installation)
51
+ - [License](#license)
52
+
53
+ ## Installation
54
+
55
+ ```console
56
+ pip install outlines-haystack
57
+ ```
58
+
59
+ ## License
60
+
61
+ `outlines-haystack` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
@@ -0,0 +1,21 @@
1
+ # `outlines-haystack`
2
+
3
+ [![PyPI - Version](https://img.shields.io/pypi/v/outlines-haystack.svg)](https://pypi.org/project/outlines-haystack)
4
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/outlines-haystack.svg)](https://pypi.org/project/outlines-haystack)
5
+
6
+ -----
7
+
8
+ ## Table of Contents
9
+
10
+ - [Installation](#installation)
11
+ - [License](#license)
12
+
13
+ ## Installation
14
+
15
+ ```console
16
+ pip install outlines-haystack
17
+ ```
18
+
19
+ ## License
20
+
21
+ `outlines-haystack` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
@@ -0,0 +1,172 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "outlines-haystack"
7
+ dynamic = ["version"]
8
+ description = 'Haystack integration with outlines.'
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = { file = "LICENSE" }
12
+ keywords = [
13
+ "nlp",
14
+ "machine-learning",
15
+ "ai",
16
+ "haystack",
17
+ "llm",
18
+ "outlines",
19
+ "structured-generation",
20
+ "generative-ai",
21
+ ]
22
+ authors = [{ name = "Edoardo Abati" }]
23
+ classifiers = [
24
+ "Development Status :: 4 - Beta",
25
+ "Programming Language :: Python",
26
+ "Programming Language :: Python :: 3.9",
27
+ "Programming Language :: Python :: 3.10",
28
+ "Programming Language :: Python :: 3.11",
29
+ "Programming Language :: Python :: 3.12",
30
+ "Programming Language :: Python :: Implementation :: CPython",
31
+ "Programming Language :: Python :: Implementation :: PyPy",
32
+ ]
33
+ dependencies = [
34
+ "haystack-ai>=2.5.0",
35
+ "outlines>=0.1.0",
36
+ ]
37
+ [project.optional-dependencies]
38
+ mlxlm = ["mlx", "mlx-lm<0.19"]
39
+ openai = ["openai"]
40
+ transformers = ["torch", "transformers", "datasets"]
41
+
42
+ [project.urls]
43
+ Documentation = "https://github.com/EdAbati/outlines-haystack#readme"
44
+ Issues = "https://github.com/EdAbati/outlines-haystack/issues"
45
+ Source = "https://github.com/EdAbati/outlines-haystack"
46
+
47
+ [tool.hatch.version]
48
+ path = "src/outlines_haystack/__about__.py"
49
+
50
+ # Default environment
51
+ [tool.hatch.envs.default]
52
+ installer = "uv"
53
+ dependencies = [
54
+ "coverage[toml]>=6.5",
55
+ "pytest",
56
+ ]
57
+ features = ["openai", "transformers"]
58
+ python = "3.9" # This is the minimum supported version
59
+
60
+ [tool.hatch.envs.default.scripts]
61
+ test = "pytest {args:tests}"
62
+ test-cov = "coverage run -m pytest {args:tests}"
63
+ cov-report = [
64
+ "- coverage combine",
65
+ "coverage report",
66
+ ]
67
+ test-cov-all = [
68
+ "test-cov",
69
+ "cov-report",
70
+ ]
71
+
72
+ [[tool.hatch.envs.all.matrix]]
73
+ python = ["3.9", "3.10", "3.11", "3.12"]
74
+
75
+
76
+ # Environment for documentation
77
+ [tool.hatch.envs.docs]
78
+ dependencies = [
79
+ "notebook",
80
+ "ipywidgets",
81
+ ]
82
+
83
+
84
+ # Type checking
85
+ [tool.hatch.envs.types]
86
+ extra-dependencies = [
87
+ "mypy>=1.0.0",
88
+ ]
89
+ [tool.hatch.envs.types.scripts]
90
+ check = "mypy --install-types --non-interactive {args:src/outlines_haystack tests}"
91
+
92
+
93
+ # Linting
94
+ [tool.hatch.envs.lint]
95
+ detached = true
96
+ dependencies = ["black>=24.3.0", "nbqa>=1.8.5", "ruff>=0.3.4"]
97
+ [tool.hatch.envs.lint.scripts]
98
+ style = [
99
+ "ruff check {args:.}",
100
+ "black --check --diff {args:.}",
101
+ "nbqa black --check --diff notebooks/*",
102
+ ]
103
+ fmt = [
104
+ "black {args:.}",
105
+ "ruff check --fix {args:.}",
106
+ "nbqa black notebooks/*",
107
+ "style",
108
+ ]
109
+
110
+
111
+ # Tools
112
+ [tool.black]
113
+ target-version = ["py39"]
114
+ line-length = 120
115
+ skip-string-normalization = true
116
+
117
+ [tool.ruff]
118
+ target-version = "py39"
119
+ line-length = 120
120
+ extend-include = ["*.ipynb"]
121
+
122
+ [tool.ruff.lint]
123
+ select = ["ALL"]
124
+ ignore = [
125
+ # No required docstring for modules, packages
126
+ "D100",
127
+ "D104",
128
+ # No future annotations
129
+ "FA100",
130
+ # Ignore checks for possible passwords
131
+ "S105",
132
+ "S106",
133
+ "S107",
134
+ ]
135
+
136
+ [tool.ruff.lint.isort]
137
+ known-first-party = ["outlines_haystack"]
138
+
139
+ [tool.ruff.lint.flake8-tidy-imports]
140
+ ban-relative-imports = "all"
141
+
142
+ [tool.ruff.lint.pydocstyle]
143
+ convention = "google"
144
+
145
+ [tool.ruff.format]
146
+ docstring-code-format = true
147
+
148
+ [tool.ruff.lint.per-file-ignores]
149
+ # Tests can use magic values, assertions, and relative imports
150
+ "tests/*" = ["PLR2004", "S101", "TID252", "D100", "D103"]
151
+
152
+ [tool.coverage.run]
153
+ source_pkgs = ["outlines_haystack", "tests"]
154
+ branch = true
155
+ parallel = true
156
+ omit = [
157
+ "src/outlines_haystack/__about__.py",
158
+ ]
159
+
160
+ [tool.coverage.paths]
161
+ outlines_haystack = [
162
+ "src/outlines_haystack",
163
+ "*/outlines-haystack/src/outlines_haystack",
164
+ ]
165
+ tests = ["tests", "*/outlines-haystack/tests"]
166
+
167
+ [tool.coverage.report]
168
+ exclude_lines = [
169
+ "no cov",
170
+ "if __name__ == .__main__.:",
171
+ "if TYPE_CHECKING:",
172
+ ]
@@ -0,0 +1,4 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+ __version__ = "0.0.1a1"
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati
2
+ #
3
+ # SPDX-License-Identifier: MIT
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati
2
+ #
3
+ # SPDX-License-Identifier: MIT
@@ -0,0 +1,146 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ import os
6
+ from collections.abc import Mapping
7
+ from typing import Any, Optional, Union
8
+
9
+ from haystack import component, default_from_dict, default_to_dict
10
+ from haystack.utils import Secret, deserialize_secrets_inplace
11
+ from outlines import generate, models
12
+ from typing_extensions import Self
13
+
14
+
15
+ class _BaseAzureOpenAIGenerator:
16
+ def __init__( # noqa: PLR0913
17
+ self,
18
+ model_name: str,
19
+ azure_endpoint: Optional[str] = None,
20
+ azure_deployment: Optional[str] = None,
21
+ api_version: Optional[str] = None,
22
+ api_key: Secret = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False), # noqa: B008
23
+ azure_ad_token: Secret = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False), # noqa: B008
24
+ organization: Optional[str] = None,
25
+ project: Optional[str] = None,
26
+ timeout: Optional[int] = None,
27
+ max_retries: Optional[int] = None,
28
+ default_headers: Union[Mapping[str, str], None] = None,
29
+ default_query: Union[Mapping[str, str], None] = None,
30
+ ) -> None:
31
+ """Initialize the Azure OpenAI generator.
32
+
33
+ Args:
34
+ model_name: The name of the OpenAI model to use. The model name is needed to load the correct tokenizer for
35
+ the model. The tokenizer is necessary for structured generation. See https://dottxt-ai.github.io/outlines/latest/reference/models/openai/#azure-openai-models
36
+ azure_endpoint: The endpoint of the deployed model, for example `https://example-resource.azure.openai.com/`.
37
+ azure_deployment: A model deployment, if given sets the base client URL to include
38
+ `/deployments/{azure_deployment}`.
39
+ api_version: The API version to use for the Azure OpenAI API.
40
+ api_key: The Azure OpenAI API key. If not provided, uses the `AZURE_OPENAI_API_KEY` environment variable.
41
+ azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
42
+ organization: The organization ID to use for the Azure OpenAI API. If not provided, uses `OPENAI_ORG_ID`
43
+ environment variable.
44
+ project: The project ID to use for the Azure OpenAI API. If not provided, uses `OPENAI_PROJECT_ID`
45
+ environment variable.
46
+ timeout: The timeout to use for the Azure OpenAI API. If not provided, uses `OPENAI_TIMEOUT` environment
47
+ variable. Defaults to 30.0.
48
+ max_retries: The maximum number of retries to use for the Azure OpenAI API. If not provided, uses
49
+ `OPENAI_MAX_RETRIES` environment variable. Defaults to 5.
50
+ default_headers: The default headers to use in the Azure OpenAI API client.
51
+ default_query: The default query parameters to use in the Azure OpenAI API client.
52
+ """
53
+ # Same defaults as in Haystack
54
+ # https://github.com/deepset-ai/haystack/blob/97126eb544be5bb7d1c5273e85597db6011b017c/haystack/components/generators/azure.py#L116-L125
55
+ azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
56
+ if not azure_endpoint:
57
+ msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
58
+ raise ValueError(msg)
59
+
60
+ if api_key is None and azure_ad_token is None:
61
+ msg = "Please provide an API key or an Azure Active Directory token."
62
+ raise ValueError(msg)
63
+
64
+ self.model_name = model_name
65
+ self.azure_endpoint = azure_endpoint
66
+ self.azure_deployment = azure_deployment
67
+ self.api_version = api_version
68
+ self.api_key = api_key
69
+ self.azure_ad_token = azure_ad_token
70
+ self.organization = organization
71
+ self.project = project
72
+
73
+ # https://github.com/deepset-ai/haystack/blob/97126eb544be5bb7d1c5273e85597db6011b017c/haystack/components/generators/azure.py#L139-L140
74
+ self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
75
+ self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
76
+
77
+ self.default_headers = default_headers
78
+ self.default_query = default_query
79
+
80
+ self.model = models.azure_openai(
81
+ deployment_name=self.azure_deployment,
82
+ model_name=self.model_name,
83
+ azure_endpoint=self.azure_endpoint,
84
+ api_version=self.api_version,
85
+ api_key=self.api_key,
86
+ azure_ad_token=self.azure_ad_token,
87
+ organization=self.organization,
88
+ project=self.project,
89
+ timeout=self.timeout,
90
+ max_retries=self.max_retries,
91
+ default_headers=self.default_headers,
92
+ default_query=self.default_query,
93
+ )
94
+
95
+ def to_dict(self) -> dict[str, Any]:
96
+ return default_to_dict(
97
+ self,
98
+ model_name=self.model_name,
99
+ azure_endpoint=self.azure_endpoint,
100
+ azure_deployment=self.azure_deployment,
101
+ api_version=self.api_version,
102
+ api_key=self.api_key.to_dict(),
103
+ azure_ad_token=self.azure_ad_token.to_dict(),
104
+ organization=self.organization,
105
+ project=self.project,
106
+ timeout=self.timeout,
107
+ max_retries=self.max_retries,
108
+ default_headers=self.default_headers,
109
+ default_query=self.default_query,
110
+ )
111
+
112
+ @classmethod
113
+ def from_dict(cls, data: dict[str, Any]) -> Self:
114
+ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
115
+ return default_from_dict(cls, data)
116
+
117
+
118
+ @component
119
+ class AzureOpenAITextGenerator(_BaseAzureOpenAIGenerator):
120
+ """A component that generates text using the Azure OpenAI API."""
121
+
122
+ @component.output_types(replies=list[str])
123
+ def run(
124
+ self,
125
+ prompt: str,
126
+ max_tokens: Optional[int] = None,
127
+ stop_at: Optional[Union[str, list[str]]] = None,
128
+ seed: Optional[int] = None,
129
+ ) -> dict[str, list[str]]:
130
+ """Run the generation component based on a prompt.
131
+
132
+ Args:
133
+ prompt: The prompt to use for generation.
134
+ max_tokens: The maximum number of tokens to generate.
135
+ stop_at: A string or list of strings after which to stop generation.
136
+ seed: The seed to use for generation.
137
+ """
138
+ if not prompt:
139
+ return {"replies": []}
140
+
141
+ if seed is not None:
142
+ self.model.config.seed = seed
143
+
144
+ generate_text_func = generate.text(self.model)
145
+ answer = generate_text_func(prompts=prompt, max_tokens=max_tokens, stop_at=stop_at)
146
+ return {"replies": [answer]}
@@ -0,0 +1,88 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ from typing import Any, Optional, Union
6
+
7
+ from haystack import component
8
+ from outlines import generate, models
9
+
10
+
11
+ class _BaseMLXLMGenerator:
12
+ def __init__(
13
+ self,
14
+ model_name: str,
15
+ tokenizer_config: Union[dict[str, Any], None] = None,
16
+ model_config: Union[dict[str, Any], None] = None,
17
+ adapter_path: Optional[str] = None,
18
+ lazy: bool = False, # noqa: FBT001, FBT002
19
+ ) -> None:
20
+ """Initialize the MLXLM generator component.
21
+
22
+ For more info, see https://dottxt-ai.github.io/outlines/latest/reference/models/mlxlm/#load-the-model
23
+
24
+ Args:
25
+ model_name: The path or the huggingface repository to load the model from.
26
+ tokenizer_config: Configuration parameters specifically for the tokenizer. Defaults to an empty dictionary.
27
+ model_config: Configuration parameters specifically for the model. Defaults to an empty dictionary.
28
+ adapter_path: Path to the LoRA adapters. If provided, applies LoRA layers to the model. Default: None.
29
+ lazy: If False eval the model parameters to make sure they are loaded in memory before returning,
30
+ otherwise they will be loaded when needed. Default: False
31
+ """
32
+ self.model_name = model_name
33
+ self.tokenizer_config = tokenizer_config if tokenizer_config is not None else {}
34
+ self.model_config = model_config if model_config is not None else {}
35
+ self.adapter_path = adapter_path
36
+ self.lazy = lazy
37
+ self.model = None
38
+
39
+ @property
40
+ def _warmed_up(self) -> bool:
41
+ return self.model is not None
42
+
43
+ def warm_up(self) -> None:
44
+ """Initializes the component."""
45
+ if self._warmed_up:
46
+ return
47
+ self.model = models.mlxlm(
48
+ model_name=self.model_name,
49
+ tokenizer_config=self.tokenizer_config,
50
+ model_config=self.model_config,
51
+ adapter_path=self.adapter_path,
52
+ lazy=self.lazy,
53
+ )
54
+
55
+ def _check_component_warmed_up(self) -> None:
56
+ if not self._warmed_up:
57
+ msg = f"The component {self.__class__.__name__} was not warmed up. Please call warm_up() before running."
58
+ raise RuntimeError(msg)
59
+
60
+
61
+ @component
62
+ class MLXLMTextGenerator(_BaseMLXLMGenerator):
63
+ """A component for generating text using an MLXLM model."""
64
+
65
+ @component.output_types(replies=list[str])
66
+ def run(
67
+ self,
68
+ prompt: str,
69
+ max_tokens: Optional[int] = None,
70
+ stop_at: Optional[Union[str, list[str]]] = None,
71
+ seed: Optional[int] = None,
72
+ ) -> dict[str, list[str]]:
73
+ """Run the generation component based on a prompt.
74
+
75
+ Args:
76
+ prompt: The prompt to use for generation.
77
+ max_tokens: The maximum number of tokens to generate.
78
+ stop_at: A string or list of strings after which to stop generation.
79
+ seed: The seed to use for generation.
80
+ """
81
+ self._check_component_warmed_up()
82
+
83
+ if not prompt:
84
+ return {"replies": []}
85
+
86
+ generate_text_func = generate.text(self.model)
87
+ answer = generate_text_func(prompts=prompt, max_tokens=max_tokens, stop_at=stop_at, seed=seed)
88
+ return {"replies": [answer]}
@@ -0,0 +1,120 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ import os
6
+ from collections.abc import Mapping
7
+ from typing import Any, Optional, Union
8
+
9
+ from haystack import component, default_from_dict, default_to_dict
10
+ from haystack.utils import Secret, deserialize_secrets_inplace
11
+ from outlines import generate, models
12
+ from typing_extensions import Self
13
+
14
+
15
+ class _BaseOpenAIGenerator:
16
+ def __init__( # noqa: PLR0913
17
+ self,
18
+ model_name: str,
19
+ api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), # noqa: B008
20
+ organization: Optional[str] = None,
21
+ project: Optional[str] = None,
22
+ base_url: Optional[str] = None,
23
+ timeout: Optional[int] = None,
24
+ max_retries: Optional[int] = None,
25
+ default_headers: Union[Mapping[str, str], None] = None,
26
+ default_query: Union[Mapping[str, str], None] = None,
27
+ ) -> None:
28
+ """Initialize the OpenAI generator.
29
+
30
+ Args:
31
+ model_name: The name of the OpenAI model to use.
32
+ api_key: The OpenAI API key. If not provided, uses the `OPENAI_API_KEY` environment variable.
33
+ organization: The organization ID to use for the OpenAI API. If not provided, uses `OPENAI_ORG_ID`
34
+ environment variable.
35
+ project: The project ID to use for the OpenAI API. If not provided, uses `OPENAI_PROJECT_ID`
36
+ environment variable.
37
+ base_url: The base URL to use for the OpenAI API. If not provided, uses `OPENAI_BASE_URL` environment
38
+ variable.
39
+ timeout: The timeout to use for the OpenAI API. If not provided, uses `OPENAI_TIMEOUT` environment variable.
40
+ Defaults to 30.0.
41
+ max_retries: The maximum number of retries to use for the OpenAI API. If not provided, uses
42
+ `OPENAI_MAX_RETRIES` environment variable. Defaults to 5.
43
+ default_headers: The default headers to use in the OpenAI API client.
44
+ default_query: The default query parameters to use in the OpenAI API client.
45
+ """
46
+ self.model_name = model_name
47
+ self.api_key = api_key
48
+ self.organization = organization
49
+ self.project = project
50
+ self.base_url = base_url
51
+
52
+ # Same defaults as in Haystack
53
+ # https://github.com/deepset-ai/haystack/blob/3ef8c081be460a91f3c5c29899a6ee6bbc429caa/haystack/components/generators/openai.py#L114-L117
54
+ self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
55
+ self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
56
+
57
+ self.default_headers = default_headers
58
+ self.default_query = default_query
59
+
60
+ self.model = models.openai(
61
+ self.model_name,
62
+ api_key=self.api_key.resolve_value(),
63
+ organization=self.organization,
64
+ project=self.project,
65
+ base_url=self.base_url,
66
+ timeout=self.timeout,
67
+ max_retries=self.max_retries,
68
+ default_headers=self.default_headers,
69
+ default_query=self.default_query,
70
+ )
71
+
72
+ def to_dict(self) -> dict[str, Any]:
73
+ return default_to_dict(
74
+ self,
75
+ model_name=self.model_name,
76
+ api_key=self.api_key.to_dict(),
77
+ organization=self.organization,
78
+ project=self.project,
79
+ base_url=self.base_url,
80
+ timeout=self.timeout,
81
+ max_retries=self.max_retries,
82
+ default_headers=self.default_headers,
83
+ default_query=self.default_query,
84
+ )
85
+
86
+ @classmethod
87
+ def from_dict(cls, data: dict[str, Any]) -> Self:
88
+ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
89
+ return default_from_dict(cls, data)
90
+
91
+
92
+ @component
93
+ class OpenAITextGenerator(_BaseOpenAIGenerator):
94
+ """A component that generates text using the OpenAI API."""
95
+
96
+ @component.output_types(replies=list[str])
97
+ def run(
98
+ self,
99
+ prompt: str,
100
+ max_tokens: Optional[int] = None,
101
+ stop_at: Optional[Union[str, list[str]]] = None,
102
+ seed: Optional[int] = None,
103
+ ) -> dict[str, list[str]]:
104
+ """Run the generation component based on a prompt.
105
+
106
+ Args:
107
+ prompt: The prompt to use for generation.
108
+ max_tokens: The maximum number of tokens to generate.
109
+ stop_at: A string or list of strings after which to stop generation.
110
+ seed: The seed to use for generation.
111
+ """
112
+ if not prompt:
113
+ return {"replies": []}
114
+
115
+ if seed is not None:
116
+ self.model.config.seed = seed
117
+
118
+ generate_text_func = generate.text(self.model)
119
+ answer = generate_text_func(prompts=prompt, max_tokens=max_tokens, stop_at=stop_at)
120
+ return {"replies": [answer]}
@@ -0,0 +1,86 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ from typing import Any, Optional, Union
6
+
7
+ from haystack import component
8
+ from outlines import generate, models
9
+
10
+
11
+ class _BaseTransformersGenerator:
12
+ def __init__(
13
+ self,
14
+ model_name: str,
15
+ device: Union[str, None] = None,
16
+ model_kwargs: Union[dict[str, Any], None] = None,
17
+ tokenizer_kwargs: Union[dict[str, Any], None] = None,
18
+ ) -> None:
19
+ """Initialize the MLXLM generator component.
20
+
21
+ For more info, see https://dottxt-ai.github.io/outlines/latest/reference/models/mlxlm/#load-the-model
22
+
23
+ Args:
24
+ model_name: The name of the model as listed on Hugging Face's model page.
25
+ device: The device(s) on which the model should be loaded. This overrides the `device_map` entry in
26
+ `model_kwargs` when provided.
27
+ model_kwargs: A dictionary that contains the keyword arguments to pass to the `from_pretrained` method
28
+ when loading the model.
29
+ tokenizer_kwargs: A dictionary that contains the keyword arguments to pass to the `from_pretrained` method
30
+ when loading the tokenizer.
31
+ """
32
+ self.model_name = model_name
33
+ self.device = device
34
+ self.model_kwargs = model_kwargs if model_kwargs is not None else {}
35
+ self.tokenizer_kwargs = tokenizer_kwargs if tokenizer_kwargs is not None else {}
36
+ self.model = None
37
+
38
+ @property
39
+ def _warmed_up(self) -> bool:
40
+ return self.model is not None
41
+
42
+ def warm_up(self) -> None:
43
+ """Initializes the component."""
44
+ if self._warmed_up:
45
+ return
46
+ self.model = models.transformers(
47
+ model_name=self.model_name,
48
+ device=self.device,
49
+ model_kwargs=self.model_kwargs,
50
+ tokenizer_kwargs=self.tokenizer_kwargs,
51
+ )
52
+
53
+ def _check_component_warmed_up(self) -> None:
54
+ if not self._warmed_up:
55
+ msg = f"The component {self.__class__.__name__} was not warmed up. Please call warm_up() before running."
56
+ raise RuntimeError(msg)
57
+
58
+
59
@component
class TransformersTextGenerator(_BaseTransformersGenerator):
    """A component for generating text using a Transformers model."""

    @component.output_types(replies=list[str])
    def run(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, list[str]]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, list[str]]:
        """Run the generation component based on a prompt.

        Args:
            prompt: The prompt to use for generation.
            max_tokens: The maximum number of tokens to generate.
            stop_at: A string or list of strings after which to stop generation.
            seed: The seed to use for generation.
        """
        # Fail fast if warm_up() was never called.
        self._check_component_warmed_up()

        # An empty prompt produces no replies rather than an API error.
        if not prompt:
            return {"replies": []}

        text_generator = generate.text(self.model)
        reply = text_generator(prompts=prompt, max_tokens=max_tokens, stop_at=stop_at, seed=seed)
        return {"replies": [reply]}
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2024-present Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
2
+ #
3
+ # SPDX-License-Identifier: MIT
@@ -0,0 +1,20 @@
1
+ from outlines_haystack.generators.azure_openai import AzureOpenAITextGenerator
2
+
3
+
4
def test_init_default(monkeypatch) -> None:  # noqa: ANN001
    # The Azure generator resolves credentials and endpoint from the environment.
    env = {
        "AZURE_OPENAI_API_KEY": "test-api-key",
        "AZURE_OPENAI_ENDPOINT": "test-endpoint",
        "OPENAI_API_VERSION": "test-api-version",
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)

    component = AzureOpenAITextGenerator(model_name="gpt-4o-mini")

    assert component.model_name == "gpt-4o-mini"
    assert component.azure_endpoint == "test-endpoint"
    assert component.api_key.resolve_value() == "test-api-key"
    assert component.azure_ad_token.resolve_value() is None
    # Optional settings that default to None.
    for attr in ("azure_deployment", "api_version", "organization", "project", "default_headers", "default_query"):
        assert getattr(component, attr) is None
    assert component.timeout == 30
    assert component.max_retries == 5
@@ -0,0 +1,15 @@
1
+ from outlines_haystack.generators.openai import OpenAITextGenerator
2
+
3
+
4
def test_init_default(monkeypatch) -> None:  # noqa: ANN001
    # The API key is read from the environment by default.
    monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")

    component = OpenAITextGenerator(model_name="gpt-4o-mini")

    assert component.model_name == "gpt-4o-mini"
    assert component.api_key.resolve_value() == "test-api-key"
    # Optional client settings that default to None.
    for attr in ("organization", "project", "base_url", "default_headers", "default_query"):
        assert getattr(component, attr) is None
    assert component.timeout == 30
    assert component.max_retries == 5
@@ -0,0 +1,10 @@
1
+ from outlines_haystack.generators.transformers import TransformersTextGenerator
2
+
3
+
4
def test_init_default() -> None:
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    component = TransformersTextGenerator(model_name=model_name, device="cpu")
    # Constructor stores its arguments verbatim and normalizes missing kwargs to {}.
    assert (component.model_name, component.device) == (model_name, "cpu")
    assert component.model_kwargs == {}
    assert component.tokenizer_kwargs == {}
    # The model itself is only loaded on warm_up().
    assert component.model is None