fastabx 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. fastabx-0.1.0/.clang-format +125 -0
  2. fastabx-0.1.0/.github/workflows/release.yml +61 -0
  3. fastabx-0.1.0/.gitignore +9 -0
  4. fastabx-0.1.0/.pre-commit-config.yaml +11 -0
  5. fastabx-0.1.0/.python-version +1 -0
  6. fastabx-0.1.0/LICENSE +21 -0
  7. fastabx-0.1.0/PKG-INFO +173 -0
  8. fastabx-0.1.0/README.md +137 -0
  9. fastabx-0.1.0/assets/gaussian.gif +0 -0
  10. fastabx-0.1.0/assets/gaussian.png +0 -0
  11. fastabx-0.1.0/pyproject.toml +66 -0
  12. fastabx-0.1.0/setup.cfg +4 -0
  13. fastabx-0.1.0/setup.py +36 -0
  14. fastabx-0.1.0/src/fastabx/__init__.py +12 -0
  15. fastabx-0.1.0/src/fastabx/__main__.py +39 -0
  16. fastabx-0.1.0/src/fastabx/cell.py +119 -0
  17. fastabx-0.1.0/src/fastabx/csrc/cuda/dtw.cu +160 -0
  18. fastabx-0.1.0/src/fastabx/csrc/dtw.cpp +107 -0
  19. fastabx-0.1.0/src/fastabx/dataset.py +201 -0
  20. fastabx-0.1.0/src/fastabx/distance.py +108 -0
  21. fastabx-0.1.0/src/fastabx/dtw.py +13 -0
  22. fastabx-0.1.0/src/fastabx/pooling.py +53 -0
  23. fastabx-0.1.0/src/fastabx/py.typed +0 -0
  24. fastabx-0.1.0/src/fastabx/score.py +79 -0
  25. fastabx-0.1.0/src/fastabx/subsample.py +59 -0
  26. fastabx-0.1.0/src/fastabx/task.py +77 -0
  27. fastabx-0.1.0/src/fastabx/utils.py +25 -0
  28. fastabx-0.1.0/src/fastabx/verify.py +104 -0
  29. fastabx-0.1.0/src/fastabx/zerospeech.py +48 -0
  30. fastabx-0.1.0/src/fastabx.egg-info/PKG-INFO +173 -0
  31. fastabx-0.1.0/src/fastabx.egg-info/SOURCES.txt +39 -0
  32. fastabx-0.1.0/src/fastabx.egg-info/dependency_links.txt +1 -0
  33. fastabx-0.1.0/src/fastabx.egg-info/entry_points.txt +2 -0
  34. fastabx-0.1.0/src/fastabx.egg-info/requires.txt +7 -0
  35. fastabx-0.1.0/src/fastabx.egg-info/top_level.txt +1 -0
  36. fastabx-0.1.0/tests/__init__.py +1 -0
  37. fastabx-0.1.0/tests/benchmark.py +74 -0
  38. fastabx-0.1.0/tests/conftest.py +30 -0
  39. fastabx-0.1.0/tests/test_dtw.py +60 -0
  40. fastabx-0.1.0/tests/test_zerospeech.py +56 -0
  41. fastabx-0.1.0/uv.lock +1545 -0
@@ -0,0 +1,125 @@
1
+ ---
2
+ AccessModifierOffset: -1
3
+ AlignAfterOpenBracket: AlwaysBreak
4
+ AlignConsecutiveAssignments: false
5
+ AlignConsecutiveDeclarations: false
6
+ AlignEscapedNewlinesLeft: true
7
+ AlignOperands: false
8
+ AlignTrailingComments: false
9
+ AllowAllParametersOfDeclarationOnNextLine: false
10
+ AllowShortBlocksOnASingleLine: false
11
+ AllowShortCaseLabelsOnASingleLine: false
12
+ AllowShortFunctionsOnASingleLine: Empty
13
+ AllowShortIfStatementsOnASingleLine: false
14
+ AllowShortLoopsOnASingleLine: false
15
+ AlwaysBreakAfterReturnType: None
16
+ AlwaysBreakBeforeMultilineStrings: true
17
+ AlwaysBreakTemplateDeclarations: true
18
+ BinPackArguments: false
19
+ BinPackParameters: false
20
+ BraceWrapping:
21
+ AfterClass: false
22
+ AfterControlStatement: false
23
+ AfterEnum: false
24
+ AfterFunction: false
25
+ AfterNamespace: false
26
+ AfterObjCDeclaration: false
27
+ AfterStruct: false
28
+ AfterUnion: false
29
+ BeforeCatch: false
30
+ BeforeElse: false
31
+ IndentBraces: false
32
+ BreakBeforeBinaryOperators: None
33
+ BreakBeforeBraces: Attach
34
+ BreakBeforeTernaryOperators: true
35
+ BreakConstructorInitializersBeforeComma: false
36
+ BreakAfterJavaFieldAnnotations: false
37
+ BreakStringLiterals: false
38
+ ColumnLimit: 119
39
+ CommentPragmas: "^ IWYU pragma:"
40
+ CompactNamespaces: false
41
+ ConstructorInitializerAllOnOneLineOrOnePerLine: true
42
+ ConstructorInitializerIndentWidth: 4
43
+ ContinuationIndentWidth: 4
44
+ Cpp11BracedListStyle: true
45
+ DerivePointerAlignment: false
46
+ DisableFormat: false
47
+ ForEachMacros:
48
+ - FOR_EACH_RANGE
49
+ - FOR_EACH
50
+ IncludeCategories:
51
+ - Regex: '^<.*\.h(pp)?>'
52
+ Priority: 1
53
+ - Regex: "^<.*"
54
+ Priority: 2
55
+ - Regex: ".*"
56
+ Priority: 3
57
+ IndentCaseLabels: true
58
+ IndentWidth: 2
59
+ IndentWrappedFunctionNames: false
60
+ KeepEmptyLinesAtTheStartOfBlocks: false
61
+ MacroBlockBegin: ""
62
+ MacroBlockEnd: ""
63
+ Macros:
64
+ - >-
65
+ PyObject_HEAD_INIT(type)={
66
+ /* this is not exactly match with PyObject_HEAD_INIT in Python source code
67
+ * but it is enough for clang-format */
68
+ { 0xFFFFFFFF },
69
+ (type)
70
+ },
71
+ - >-
72
+ PyVarObject_HEAD_INIT(type, size)={
73
+ {
74
+ /* manually expand PyObject_HEAD_INIT(type) above
75
+ * because clang-format do not support recursive expansion */
76
+ { 0xFFFFFFFF },
77
+ (type)
78
+ },
79
+ (size)
80
+ },
81
+ MaxEmptyLinesToKeep: 1
82
+ NamespaceIndentation: None
83
+ PenaltyBreakBeforeFirstCallParameter: 1
84
+ PenaltyBreakComment: 300
85
+ PenaltyBreakFirstLessLess: 120
86
+ PenaltyBreakString: 1000
87
+ PenaltyExcessCharacter: 1000000
88
+ PenaltyReturnTypeOnItsOwnLine: 2000000
89
+ PointerAlignment: Left
90
+ ReflowComments: true
91
+ SortIncludes: true
92
+ SpaceAfterCStyleCast: false
93
+ SpaceBeforeAssignmentOperators: true
94
+ SpaceBeforeParens: ControlStatements
95
+ SpaceInEmptyParentheses: false
96
+ SpacesBeforeTrailingComments: 1
97
+ SpacesInAngles: false
98
+ SpacesInContainerLiterals: true
99
+ SpacesInCStyleCastParentheses: false
100
+ SpacesInParentheses: false
101
+ SpacesInSquareBrackets: false
102
+ Standard: c++17
103
+ StatementMacros:
104
+ - C10_DEFINE_bool
105
+ - C10_DEFINE_int
106
+ - C10_DEFINE_int32
107
+ - C10_DEFINE_int64
108
+ - C10_DEFINE_string
109
+ - C10_DEFINE_REGISTRY_WITHOUT_WARNING
110
+ - C10_REGISTER_CREATOR
111
+ - DEFINE_BINARY
112
+ - PyObject_HEAD
113
+ - PyObject_VAR_HEAD
114
+ - PyException_HEAD
115
+ - TORCH_DECLARE_bool
116
+
117
+ TabWidth: 8
118
+ UseTab: Never
119
+ ---
120
+ Language: ObjC
121
+ ColumnLimit: 120
122
+ AlignAfterOpenBracket: Align
123
+ ObjCBlockIndentWidth: 2
124
+ ObjCSpaceAfterProperty: false
125
+ ObjCSpaceBeforeProtocolList: false
@@ -0,0 +1,61 @@
1
name: Release Workflow

on:
  push:
    tags:
      - "*"

jobs:
  build:
    strategy:
      matrix:
        os: [ubuntu-latest, macos-15]
    runs-on: ${{ matrix.os }}

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "latest"

      - name: Build
        run: uv build

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: dist-${{ matrix.os }}
          path: ./dist/*

  release:
    needs: build
    runs-on: ubuntu-latest
    # `gh release create` needs write access to the repository contents; the default
    # GITHUB_TOKEN may be read-only depending on repository settings.
    permissions:
      contents: write
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          # Each artifact is extracted into its own directory: dist/dist-<os>/
          path: dist

      - name: List artifacts
        run: find ./dist

      - name: Create Release
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          # For a tag push, ref_name is the tag name (same as stripping refs/tags/ from GITHUB_REF).
          TAG: ${{ github.ref_name }}
        # Attach the wheels from every OS, plus only the sdist built on Ubuntu.
        # The sdist is matched by glob: it is named after the package ("fastabx-<version>.tar.gz"),
        # not "abx-<tag>", and the version string need not equal the tag (e.g. a "v" prefix).
        run: gh release create "$TAG" ./dist/*/*.whl ./dist/dist-ubuntu-latest/*.tar.gz
@@ -0,0 +1,9 @@
1
+ .DS_Store
2
+ *.so
3
+ *.c
4
+ __pycache__
5
+ .vscode
6
+ *.egg-info
7
+ *.code-workspace
8
+ .envrc
9
+ build
@@ -0,0 +1,11 @@
1
+ repos:
2
+ - repo: https://github.com/crate-ci/typos
3
+ rev: v1.29.4
4
+ hooks:
5
+ - id: typos
6
+ - repo: https://github.com/astral-sh/ruff-pre-commit
7
+ rev: v0.9.1
8
+ hooks:
9
+ - id: ruff
10
+ args: [ --fix ]
11
+ - id: ruff-format
@@ -0,0 +1 @@
1
+ 3.12
fastabx-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Maxime Poli.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
fastabx-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,173 @@
1
+ Metadata-Version: 2.2
2
+ Name: fastabx
3
+ Version: 0.1.0
4
+ Summary: Fast ABX
5
+ Author: Maxime Poli
6
+ License: MIT License
7
+
8
+ Copyright (c) 2025 Maxime Poli.
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+
28
+ Requires-Python: >=3.12
29
+ Description-Content-Type: text/markdown
30
+ Requires-Dist: numpy>=2.1.3
31
+ Requires-Dist: polars>=1.14.0
32
+ Requires-Dist: torch>=2.6.0
33
+ Requires-Dist: tqdm>=4.67.1
34
+ Provides-Extra: gpu
35
+ Requires-Dist: polars[gpu]>=1.14.0; extra == "gpu"
36
+
37
+ # Fast ABX
38
+
39
+ ## Motivation
40
+
41
+ 1. Simple and generic API
42
+ 2. As fast as possible
43
+
44
+ This library aims to be as clear and minimal as possible to make its maintenance easy,
45
+ and the code readable and quick to understand. It should be easy to incorporate
46
+ different components into one's personal code, and not just use it as a black box.
47
+
48
+ At the same time, it must be as fast as possible to calculate the ABX, both in
49
+ forming triplets and calculating the distances themselves, while offering the
50
+ possibility to use any configuration of "on," "by," and "across" conditions.
51
+
52
+ The idea of creating yet again a new ABX library comes from the realization
53
+ that the [polars](https://github.com/pola-rs/polars) library efficiently and easily
54
+ solves the difficulties associated with creating triplets.
55
+
56
+ We can write the creation of the triplets as some "join" and "select" operations
57
+ on dataframes, then some "filter" for subsampling. With `polars`, the full query
58
+ is built lazily and then processed end-to-end. The backend will run several
59
+ optimizations for us, and can even run on GPU. We don't have to worry anymore
60
+ about how to build the triplets in a clever manner.
61
+
62
+ The computation of the distances is similar to that of
63
+ [libri-light-abx2](https://github.com/zerospeech/libri-light-abx2).
64
+ The important change is that now the DTW is computed in a PyTorch C++ extension,
65
+ with CPU (using OpenMP) and CUDA backends. The speedup is most noticeable on
66
+ large cells, such as those obtained when running the Phoneme ABX without
67
+ context conditions.
68
+
69
+ ## Installation
70
+
71
+ ### Pre-built package
72
+
73
+ To use the package on macOS or Linux x86-64, install it with `pip` in your environment with
74
+ Python >=3.12:
75
+
76
+ ```bash
77
+ pip install fastabx
78
+ ```
79
+
80
+ We will add support for other platforms soon.
81
+
82
+ ### Build from source
83
+
84
+ Download this repository and run:
85
+
86
+ ```bash
87
+ CXX="g++" uv build --wheel
88
+ ```
89
+
90
+ If you want CUDA support, you must have the CUDA toolkit installed,
91
+ and the `CUDA_HOME` environment variable set.
92
+
93
+ ## Examples
94
+
95
+ ### ZeroSpeech ABX task
96
+
97
+ #### Python
98
+
99
+ ```py
100
+ import torch
101
+
102
+ from fastabx import Dataset, Score, Subsampler, Task
103
+
104
+ dataset = Dataset.from_item("./triphone-dev-clean.item", "./features/dev-clean", 50, torch.load)
105
+ task = Task(dataset, on="#phone", by=["next-phone", "prev-phone", "speaker"], subsampler=Subsampler())
106
+ score = Score(task, "cosine")
107
+ print(score.collapse(levels=[("next-phone", "prev-phone"), "speaker"]))
108
+ ```
109
+
110
+ #### CLI
111
+
112
+ ```bash
113
+ ❯ fastabx --help
114
+ usage: fastabx [-h] [--frequency FREQUENCY] [--speaker {within,across}] [--context {within,any}]
115
+ [--distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}]
116
+ [--max-size-group MAX_SIZE_GROUP] [--max-x-across MAX_X_ACROSS]
117
+ [--seed SEED]
118
+ item features
119
+
120
+ ZeroSpeech ABX
121
+
122
+ positional arguments:
123
+ item Path to the item file
124
+ features Path to the features directory
125
+
126
+ options:
127
+ -h, --help show this help message and exit
128
+ --frequency FREQUENCY
129
+ Feature frequency (in Hz)
130
+ --speaker {within,across}
131
+ --context {within,any}
132
+ --distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}
133
+ --max-size-group MAX_SIZE_GROUP
134
+ Maximum size of a cell
135
+ --max-x-across MAX_X_ACROSS
136
+ With 'across', maximum number of X given (A, B)
137
+ --seed SEED
138
+ ```
139
+
140
+ ### ABX between two gaussians
141
+
142
+ ```python
143
+ import matplotlib.pyplot as plt
144
+ import numpy as np
145
+
146
+ from fastabx import Dataset, Score, Task
147
+
148
+ n = 100
149
+ diagonal_shift = 4
150
+ mean = np.zeros(2)
151
+ cov = np.array([[4, -2], [-2, 3]])
152
+
153
+ rng = np.random.default_rng(seed=0)
154
+ first = rng.multivariate_normal(mean, cov, n)
155
+ second = rng.multivariate_normal(mean + np.ones(2) * diagonal_shift, cov, n)
156
+
157
+ dataset = Dataset.from_numpy(np.vstack([first, second]), {"label": [0] * n + [1] * n})
158
+ task = Task(dataset, on="label")
159
+ score = Score(task, "euclidean")
160
+
161
+ plt.plot(*first.T, ".", alpha=0.5)
162
+ plt.plot(*second.T, ".", alpha=0.5)
163
+ plt.axis("equal")
164
+ plt.grid()
165
+ plt.title(f"ABX: {score.collapse():.3%}")
166
+ plt.show()
167
+ ```
168
+
169
+ <img src="./assets/gaussian.png" width=70%>
170
+
171
+ #### ABX with increasing shift between the two gaussians
172
+
173
+ ![gaussian_animation](./assets/gaussian.gif)
@@ -0,0 +1,137 @@
1
+ # Fast ABX
2
+
3
+ ## Motivation
4
+
5
+ 1. Simple and generic API
6
+ 2. As fast as possible
7
+
8
+ This library aims to be as clear and minimal as possible to make its maintenance easy,
9
+ and the code readable and quick to understand. It should be easy to incorporate
10
+ different components into one's personal code, and not just use it as a black box.
11
+
12
+ At the same time, it must be as fast as possible to calculate the ABX, both in
13
+ forming triplets and calculating the distances themselves, while offering the
14
+ possibility to use any configuration of "on," "by," and "across" conditions.
15
+
16
+ The idea of creating yet again a new ABX library comes from the realization
17
+ that the [polars](https://github.com/pola-rs/polars) library efficiently and easily
18
+ solves the difficulties associated with creating triplets.
19
+
20
+ We can write the creation of the triplets as some "join" and "select" operations
21
+ on dataframes, then some "filter" for subsampling. With `polars`, the full query
22
+ is built lazily and then processed end-to-end. The backend will run several
23
+ optimizations for us, and can even run on GPU. We don't have to worry anymore
24
+ about how to build the triplets in a clever manner.
25
+
26
+ The computation of the distances is similar to that of
27
+ [libri-light-abx2](https://github.com/zerospeech/libri-light-abx2).
28
+ The important change is that now the DTW is computed in a PyTorch C++ extension,
29
+ with CPU (using OpenMP) and CUDA backends. The speedup is most noticeable on
30
+ large cells, such as those obtained when running the Phoneme ABX without
31
+ context conditions.
32
+
33
+ ## Installation
34
+
35
+ ### Pre-built package
36
+
37
+ To use the package on macOS or Linux x86-64, install it with `pip` in your environment with
38
+ Python >=3.12:
39
+
40
+ ```bash
41
+ pip install fastabx
42
+ ```
43
+
44
+ We will add support for other platforms soon.
45
+
46
+ ### Build from source
47
+
48
+ Download this repository and run:
49
+
50
+ ```bash
51
+ CXX="g++" uv build --wheel
52
+ ```
53
+
54
+ If you want CUDA support, you must have the CUDA toolkit installed,
55
+ and the `CUDA_HOME` environment variable set.
56
+
57
+ ## Examples
58
+
59
+ ### ZeroSpeech ABX task
60
+
61
+ #### Python
62
+
63
+ ```py
64
+ import torch
65
+
66
+ from fastabx import Dataset, Score, Subsampler, Task
67
+
68
+ dataset = Dataset.from_item("./triphone-dev-clean.item", "./features/dev-clean", 50, torch.load)
69
+ task = Task(dataset, on="#phone", by=["next-phone", "prev-phone", "speaker"], subsampler=Subsampler())
70
+ score = Score(task, "cosine")
71
+ print(score.collapse(levels=[("next-phone", "prev-phone"), "speaker"]))
72
+ ```
73
+
74
+ #### CLI
75
+
76
+ ```bash
77
+ ❯ fastabx --help
78
+ usage: fastabx [-h] [--frequency FREQUENCY] [--speaker {within,across}] [--context {within,any}]
79
+ [--distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}]
80
+ [--max-size-group MAX_SIZE_GROUP] [--max-x-across MAX_X_ACROSS]
81
+ [--seed SEED]
82
+ item features
83
+
84
+ ZeroSpeech ABX
85
+
86
+ positional arguments:
87
+ item Path to the item file
88
+ features Path to the features directory
89
+
90
+ options:
91
+ -h, --help show this help message and exit
92
+ --frequency FREQUENCY
93
+ Feature frequency (in Hz)
94
+ --speaker {within,across}
95
+ --context {within,any}
96
+ --distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}
97
+ --max-size-group MAX_SIZE_GROUP
98
+ Maximum size of a cell
99
+ --max-x-across MAX_X_ACROSS
100
+ With 'across', maximum number of X given (A, B)
101
+ --seed SEED
102
+ ```
103
+
104
+ ### ABX between two gaussians
105
+
106
+ ```python
107
+ import matplotlib.pyplot as plt
108
+ import numpy as np
109
+
110
+ from fastabx import Dataset, Score, Task
111
+
112
+ n = 100
113
+ diagonal_shift = 4
114
+ mean = np.zeros(2)
115
+ cov = np.array([[4, -2], [-2, 3]])
116
+
117
+ rng = np.random.default_rng(seed=0)
118
+ first = rng.multivariate_normal(mean, cov, n)
119
+ second = rng.multivariate_normal(mean + np.ones(2) * diagonal_shift, cov, n)
120
+
121
+ dataset = Dataset.from_numpy(np.vstack([first, second]), {"label": [0] * n + [1] * n})
122
+ task = Task(dataset, on="label")
123
+ score = Score(task, "euclidean")
124
+
125
+ plt.plot(*first.T, ".", alpha=0.5)
126
+ plt.plot(*second.T, ".", alpha=0.5)
127
+ plt.axis("equal")
128
+ plt.grid()
129
+ plt.title(f"ABX: {score.collapse():.3%}")
130
+ plt.show()
131
+ ```
132
+
133
+ <img src="./assets/gaussian.png" width=70%>
134
+
135
+ #### ABX with increasing shift between the two gaussians
136
+
137
+ ![gaussian_animation](./assets/gaussian.gif)
Binary file
Binary file
@@ -0,0 +1,66 @@
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=75.0.0",
4
+ "setuptools-scm>=8.1.0",
5
+ "torch>=2.6.0",
6
+ "numpy>=2.0.2",
7
+ "ninja>=1.11",
8
+ ]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "fastabx"
13
+ version = "0.1.0"
14
+ description = "Fast ABX"
15
+ readme = "README.md"
16
+ requires-python = ">=3.12"
17
+ authors = [ {name = "Maxime Poli"} ]
18
+ license = { file = "LICENSE" } # To update to PEP 639 once support is added to setuptools
19
+ dependencies = [
20
+ "numpy>=2.1.3",
21
+ "polars>=1.14.0",
22
+ "torch>=2.6.0",
23
+ "tqdm>=4.67.1",
24
+ ]
25
+
26
+ [project.optional-dependencies]
27
+ gpu = [
28
+ "polars[gpu]>=1.14.0",
29
+ ]
30
+
31
+ [project.scripts]
32
+ fastabx = "fastabx.__main__:main"
33
+
34
+ [dependency-groups]
35
+ dev = [
36
+ "ipdb>=0.13.13",
37
+ "ipykernel>=6.29.5",
38
+ "ipywidgets>=8.1.5",
39
+ "matplotlib>=3.10.0",
40
+ "mypy>=1.14.1",
41
+ "pre-commit>=4.0.1",
42
+ "ruff>=0.9.1",
43
+ "types-tqdm>=4.67.0.20241221",
44
+ "typos>=1.29.4",
45
+ ]
46
+ test = [
47
+ "hypothesis>=6.125.1",
48
+ "pytest>=8.3.4",
49
+ ]
50
+
51
+ [tool.ruff]
52
+ line-length = 119
53
+
54
+ [tool.ruff.lint]
55
+ select = ["ALL"]
56
+ ignore = [
57
+ "COM812", # missing-trailing-comma
58
+ "D105", # undocumented-magic-method
59
+ "D107", # undocumented-public-init
60
+ "EM101", # raw-string-in-exception
61
+ "PD901", # pandas-df-variable-name
62
+ "TRY003", # raise-vanilla-args
63
+ ]
64
+
65
+ [tool.setuptools]
66
+ license-files = [] # https://github.com/pypa/setuptools/issues/4759
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
fastabx-0.1.0/setup.py ADDED
@@ -0,0 +1,36 @@
1
+ """Build the DTW PyTorch C++ extension."""
2
+
3
+ import os
4
+ import sys
5
+
6
+ from setuptools import Extension, setup
7
+ from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension, CUDAExtension
8
+
9
+
10
def get_extension() -> Extension:
    """Describe the ``fastabx._C`` extension, including the CUDA source when CUDA is available.

    CUDA support is enabled when ``torch.utils.cpp_extension.CUDA_HOME`` is set; otherwise only
    the CPU source is compiled. On Linux, the C++ code is compiled and linked with ``-fopenmp``.

    The CUDA architectures default to ``Volta;Turing;Ampere;Ada;Hopper``. They can be overridden
    by setting the ``TORCH_CUDA_ARCH_LIST`` environment variable before building.

    Returns:
        A ``CUDAExtension`` or ``CppExtension`` built against the CPython 3.12 limited API.
    """
    use_cuda = CUDA_HOME is not None
    extension_cls = CUDAExtension if use_cuda else CppExtension
    # The OpenMP flag is only passed on Linux; other platforms get a serial CPU build.
    openmp = ["-fopenmp"] if sys.platform == "linux" else []
    extra_compile_args = {
        # Py_LIMITED_API=0x030c0000 restricts the C++ code to the CPython 3.12 stable ABI,
        # consistent with py_limited_api=True below and the "cp312" wheel tag passed to setup().
        "cxx": ["-fdiagnostics-color=always", "-DPy_LIMITED_API=0x030c0000", "-O3", *openmp],
        "nvcc": ["-O3"],
    }
    sources = ["src/fastabx/csrc/dtw.cpp"]
    if use_cuda:
        # Respect an architecture list provided by the user. Examples: targeting a GPU generation
        # that is not in the default list, or dropping one the installed toolkit cannot compile for.
        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "Volta;Turing;Ampere;Ada;Hopper")
        sources.append("src/fastabx/csrc/cuda/dtw.cu")
    return extension_cls(
        "fastabx._C",
        sources,
        extra_compile_args=extra_compile_args,
        extra_link_args=openmp,
        py_limited_api=True,
    )
30
+
31
+
32
# Build the extension with torch's BuildExtension command. The wheel is tagged abi3 for
# CPython >= 3.12 because get_extension() compiles against the 3.12 limited API.
setup(
    cmdclass={"build_ext": BuildExtension},
    ext_modules=[get_extension()],
    options={"bdist_wheel": {"py_limited_api": "cp312"}},
)
@@ -0,0 +1,12 @@
1
+ """Full ABX."""
2
+
3
+ from fastabx.dataset import Dataset
4
+ from fastabx.pooling import pooling
5
+ from fastabx.score import Score
6
+ from fastabx.subsample import Subsampler
7
+ from fastabx.task import Task
8
+ from fastabx.zerospeech import zerospeech_abx
9
+
10
+ from . import _C # type: ignore[attr-defined] # Load the PyTorch C++ extension.
11
+
12
# Names exported by ``from fastabx import *``. ``_C`` (the compiled PyTorch extension
# imported above) is exported alongside the public classes and functions.
__all__ = [
    "_C",
    "Dataset",
    "Score",
    "Subsampler",
    "Task",
    "pooling",
    "zerospeech_abx",
]