fastabx 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastabx-0.1.0/.clang-format +125 -0
- fastabx-0.1.0/.github/workflows/release.yml +61 -0
- fastabx-0.1.0/.gitignore +9 -0
- fastabx-0.1.0/.pre-commit-config.yaml +11 -0
- fastabx-0.1.0/.python-version +1 -0
- fastabx-0.1.0/LICENSE +21 -0
- fastabx-0.1.0/PKG-INFO +173 -0
- fastabx-0.1.0/README.md +137 -0
- fastabx-0.1.0/assets/gaussian.gif +0 -0
- fastabx-0.1.0/assets/gaussian.png +0 -0
- fastabx-0.1.0/pyproject.toml +66 -0
- fastabx-0.1.0/setup.cfg +4 -0
- fastabx-0.1.0/setup.py +36 -0
- fastabx-0.1.0/src/fastabx/__init__.py +12 -0
- fastabx-0.1.0/src/fastabx/__main__.py +39 -0
- fastabx-0.1.0/src/fastabx/cell.py +119 -0
- fastabx-0.1.0/src/fastabx/csrc/cuda/dtw.cu +160 -0
- fastabx-0.1.0/src/fastabx/csrc/dtw.cpp +107 -0
- fastabx-0.1.0/src/fastabx/dataset.py +201 -0
- fastabx-0.1.0/src/fastabx/distance.py +108 -0
- fastabx-0.1.0/src/fastabx/dtw.py +13 -0
- fastabx-0.1.0/src/fastabx/pooling.py +53 -0
- fastabx-0.1.0/src/fastabx/py.typed +0 -0
- fastabx-0.1.0/src/fastabx/score.py +79 -0
- fastabx-0.1.0/src/fastabx/subsample.py +59 -0
- fastabx-0.1.0/src/fastabx/task.py +77 -0
- fastabx-0.1.0/src/fastabx/utils.py +25 -0
- fastabx-0.1.0/src/fastabx/verify.py +104 -0
- fastabx-0.1.0/src/fastabx/zerospeech.py +48 -0
- fastabx-0.1.0/src/fastabx.egg-info/PKG-INFO +173 -0
- fastabx-0.1.0/src/fastabx.egg-info/SOURCES.txt +39 -0
- fastabx-0.1.0/src/fastabx.egg-info/dependency_links.txt +1 -0
- fastabx-0.1.0/src/fastabx.egg-info/entry_points.txt +2 -0
- fastabx-0.1.0/src/fastabx.egg-info/requires.txt +7 -0
- fastabx-0.1.0/src/fastabx.egg-info/top_level.txt +1 -0
- fastabx-0.1.0/tests/__init__.py +1 -0
- fastabx-0.1.0/tests/benchmark.py +74 -0
- fastabx-0.1.0/tests/conftest.py +30 -0
- fastabx-0.1.0/tests/test_dtw.py +60 -0
- fastabx-0.1.0/tests/test_zerospeech.py +56 -0
- fastabx-0.1.0/uv.lock +1545 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
---
|
|
2
|
+
AccessModifierOffset: -1
|
|
3
|
+
AlignAfterOpenBracket: AlwaysBreak
|
|
4
|
+
AlignConsecutiveAssignments: false
|
|
5
|
+
AlignConsecutiveDeclarations: false
|
|
6
|
+
AlignEscapedNewlinesLeft: true
|
|
7
|
+
AlignOperands: false
|
|
8
|
+
AlignTrailingComments: false
|
|
9
|
+
AllowAllParametersOfDeclarationOnNextLine: false
|
|
10
|
+
AllowShortBlocksOnASingleLine: false
|
|
11
|
+
AllowShortCaseLabelsOnASingleLine: false
|
|
12
|
+
AllowShortFunctionsOnASingleLine: Empty
|
|
13
|
+
AllowShortIfStatementsOnASingleLine: false
|
|
14
|
+
AllowShortLoopsOnASingleLine: false
|
|
15
|
+
AlwaysBreakAfterReturnType: None
|
|
16
|
+
AlwaysBreakBeforeMultilineStrings: true
|
|
17
|
+
AlwaysBreakTemplateDeclarations: true
|
|
18
|
+
BinPackArguments: false
|
|
19
|
+
BinPackParameters: false
|
|
20
|
+
BraceWrapping:
|
|
21
|
+
AfterClass: false
|
|
22
|
+
AfterControlStatement: false
|
|
23
|
+
AfterEnum: false
|
|
24
|
+
AfterFunction: false
|
|
25
|
+
AfterNamespace: false
|
|
26
|
+
AfterObjCDeclaration: false
|
|
27
|
+
AfterStruct: false
|
|
28
|
+
AfterUnion: false
|
|
29
|
+
BeforeCatch: false
|
|
30
|
+
BeforeElse: false
|
|
31
|
+
IndentBraces: false
|
|
32
|
+
BreakBeforeBinaryOperators: None
|
|
33
|
+
BreakBeforeBraces: Attach
|
|
34
|
+
BreakBeforeTernaryOperators: true
|
|
35
|
+
BreakConstructorInitializersBeforeComma: false
|
|
36
|
+
BreakAfterJavaFieldAnnotations: false
|
|
37
|
+
BreakStringLiterals: false
|
|
38
|
+
ColumnLimit: 119
|
|
39
|
+
CommentPragmas: "^ IWYU pragma:"
|
|
40
|
+
CompactNamespaces: false
|
|
41
|
+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
|
42
|
+
ConstructorInitializerIndentWidth: 4
|
|
43
|
+
ContinuationIndentWidth: 4
|
|
44
|
+
Cpp11BracedListStyle: true
|
|
45
|
+
DerivePointerAlignment: false
|
|
46
|
+
DisableFormat: false
|
|
47
|
+
ForEachMacros:
|
|
48
|
+
- FOR_EACH_RANGE
|
|
49
|
+
- FOR_EACH
|
|
50
|
+
IncludeCategories:
|
|
51
|
+
- Regex: '^<.*\.h(pp)?>'
|
|
52
|
+
Priority: 1
|
|
53
|
+
- Regex: "^<.*"
|
|
54
|
+
Priority: 2
|
|
55
|
+
- Regex: ".*"
|
|
56
|
+
Priority: 3
|
|
57
|
+
IndentCaseLabels: true
|
|
58
|
+
IndentWidth: 2
|
|
59
|
+
IndentWrappedFunctionNames: false
|
|
60
|
+
KeepEmptyLinesAtTheStartOfBlocks: false
|
|
61
|
+
MacroBlockBegin: ""
|
|
62
|
+
MacroBlockEnd: ""
|
|
63
|
+
Macros:
|
|
64
|
+
- >-
|
|
65
|
+
PyObject_HEAD_INIT(type)={
|
|
66
|
+
/* this is not exactly match with PyObject_HEAD_INIT in Python source code
|
|
67
|
+
* but it is enough for clang-format */
|
|
68
|
+
{ 0xFFFFFFFF },
|
|
69
|
+
(type)
|
|
70
|
+
},
|
|
71
|
+
- >-
|
|
72
|
+
PyVarObject_HEAD_INIT(type, size)={
|
|
73
|
+
{
|
|
74
|
+
/* manually expand PyObject_HEAD_INIT(type) above
|
|
75
|
+
* because clang-format do not support recursive expansion */
|
|
76
|
+
{ 0xFFFFFFFF },
|
|
77
|
+
(type)
|
|
78
|
+
},
|
|
79
|
+
(size)
|
|
80
|
+
},
|
|
81
|
+
MaxEmptyLinesToKeep: 1
|
|
82
|
+
NamespaceIndentation: None
|
|
83
|
+
PenaltyBreakBeforeFirstCallParameter: 1
|
|
84
|
+
PenaltyBreakComment: 300
|
|
85
|
+
PenaltyBreakFirstLessLess: 120
|
|
86
|
+
PenaltyBreakString: 1000
|
|
87
|
+
PenaltyExcessCharacter: 1000000
|
|
88
|
+
PenaltyReturnTypeOnItsOwnLine: 2000000
|
|
89
|
+
PointerAlignment: Left
|
|
90
|
+
ReflowComments: true
|
|
91
|
+
SortIncludes: true
|
|
92
|
+
SpaceAfterCStyleCast: false
|
|
93
|
+
SpaceBeforeAssignmentOperators: true
|
|
94
|
+
SpaceBeforeParens: ControlStatements
|
|
95
|
+
SpaceInEmptyParentheses: false
|
|
96
|
+
SpacesBeforeTrailingComments: 1
|
|
97
|
+
SpacesInAngles: false
|
|
98
|
+
SpacesInContainerLiterals: true
|
|
99
|
+
SpacesInCStyleCastParentheses: false
|
|
100
|
+
SpacesInParentheses: false
|
|
101
|
+
SpacesInSquareBrackets: false
|
|
102
|
+
Standard: c++17
|
|
103
|
+
StatementMacros:
|
|
104
|
+
- C10_DEFINE_bool
|
|
105
|
+
- C10_DEFINE_int
|
|
106
|
+
- C10_DEFINE_int32
|
|
107
|
+
- C10_DEFINE_int64
|
|
108
|
+
- C10_DEFINE_string
|
|
109
|
+
- C10_DEFINE_REGISTRY_WITHOUT_WARNING
|
|
110
|
+
- C10_REGISTER_CREATOR
|
|
111
|
+
- DEFINE_BINARY
|
|
112
|
+
- PyObject_HEAD
|
|
113
|
+
- PyObject_VAR_HEAD
|
|
114
|
+
- PyException_HEAD
|
|
115
|
+
- TORCH_DECLARE_bool
|
|
116
|
+
|
|
117
|
+
TabWidth: 8
|
|
118
|
+
UseTab: Never
|
|
119
|
+
---
|
|
120
|
+
Language: ObjC
|
|
121
|
+
ColumnLimit: 120
|
|
122
|
+
AlignAfterOpenBracket: Align
|
|
123
|
+
ObjCBlockIndentWidth: 2
|
|
124
|
+
ObjCSpaceAfterProperty: false
|
|
125
|
+
ObjCSpaceBeforeProtocolList: false
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
name: Release Workflow
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
tags:
|
|
6
|
+
- "*"
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
build:
|
|
10
|
+
strategy:
|
|
11
|
+
matrix:
|
|
12
|
+
os: [ubuntu-latest, macos-15]
|
|
13
|
+
runs-on: ${{ matrix.os }}
|
|
14
|
+
|
|
15
|
+
steps:
|
|
16
|
+
- name: Checkout code
|
|
17
|
+
uses: actions/checkout@v4
|
|
18
|
+
|
|
19
|
+
- name: Get tag name
|
|
20
|
+
id: get_tag
|
|
21
|
+
run: echo "TAG_NAME=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
|
22
|
+
|
|
23
|
+
- name: Install uv
|
|
24
|
+
uses: astral-sh/setup-uv@v5
|
|
25
|
+
with:
|
|
26
|
+
version: "latest"
|
|
27
|
+
|
|
28
|
+
- name: Build
|
|
29
|
+
run: uv build
|
|
30
|
+
|
|
31
|
+
- name: Upload artifacts
|
|
32
|
+
uses: actions/upload-artifact@v4
|
|
33
|
+
with:
|
|
34
|
+
name: dist-${{ matrix.os }}
|
|
35
|
+
path: ./dist/*
|
|
36
|
+
|
|
37
|
+
release:
|
|
38
|
+
needs: build
|
|
39
|
+
runs-on: ubuntu-latest
|
|
40
|
+
steps:
|
|
41
|
+
- name: Checkout code
|
|
42
|
+
uses: actions/checkout@v4
|
|
43
|
+
|
|
44
|
+
- name: Get tag name
|
|
45
|
+
id: get_tag
|
|
46
|
+
run: echo "TAG_NAME=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
|
47
|
+
|
|
48
|
+
- name: Download all artifacts
|
|
49
|
+
uses: actions/download-artifact@v4
|
|
50
|
+
with:
|
|
51
|
+
path: dist
|
|
52
|
+
|
|
53
|
+
- name: List artifacts
|
|
54
|
+
run:
|
|
55
|
+
find ./dist
|
|
56
|
+
|
|
57
|
+
- name: Create Release
|
|
58
|
+
env:
|
|
59
|
+
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
60
|
+
TAG: ${{ steps.get_tag.outputs.TAG_NAME }}
|
|
61
|
+
# `uv build` names the sdist after the project ("fastabx-<version>.tar.gz"), not "abx-...";
# glob for it so the upload also works when tags carry a prefix such as "v".
run: gh release create "$TAG" ./dist/**/*.whl ./dist/dist-ubuntu-latest/fastabx-*.tar.gz
|
fastabx-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.12
|
fastabx-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Maxime Poli.
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
fastabx-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: fastabx
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Fast ABX
|
|
5
|
+
Author: Maxime Poli
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2025 Maxime Poli.
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
|
|
28
|
+
Requires-Python: >=3.12
|
|
29
|
+
Description-Content-Type: text/markdown
|
|
30
|
+
Requires-Dist: numpy>=2.1.3
|
|
31
|
+
Requires-Dist: polars>=1.14.0
|
|
32
|
+
Requires-Dist: torch>=2.6.0
|
|
33
|
+
Requires-Dist: tqdm>=4.67.1
|
|
34
|
+
Provides-Extra: gpu
|
|
35
|
+
Requires-Dist: polars[gpu]>=1.14.0; extra == "gpu"
|
|
36
|
+
|
|
37
|
+
# Fast ABX
|
|
38
|
+
|
|
39
|
+
## Motivation
|
|
40
|
+
|
|
41
|
+
1. Simple and generic API
|
|
42
|
+
2. As fast as possible
|
|
43
|
+
|
|
44
|
+
This library aims to be as clear and minimal as possible to make its maintenance easy,
|
|
45
|
+
and the code readable and quick to understand. It should be easy to incorporate
|
|
46
|
+
different components into one's personal code, and not just use it as a black box.
|
|
47
|
+
|
|
48
|
+
At the same time, it must be as fast as possible to calculate the ABX, both in
|
|
49
|
+
forming triplets and calculating the distances themselves, while offering the
|
|
50
|
+
possibility to use any configuration of "on," "by," and "across" conditions.
|
|
51
|
+
|
|
52
|
+
The idea of creating yet another ABX library comes from the realization
|
|
53
|
+
that the [polars](https://github.com/pola-rs/polars) library efficiently and easily
|
|
54
|
+
solves the difficulties associated with creating triplets.
|
|
55
|
+
|
|
56
|
+
We can write the creation of the triplets as some "join" and "select" operations
|
|
57
|
+
on dataframes, then some "filter" for subsampling. With `polars`, the full query
|
|
58
|
+
is built lazily and then processed end-to-end. The backend will run several
|
|
59
|
+
optimizations for us, and can even run on GPU. We don't have to worry anymore
|
|
60
|
+
about how to build the triplets in a clever manner.
|
|
61
|
+
|
|
62
|
+
The computation of the distances is similar to
|
|
63
|
+
[libri-light-abx2](https://github.com/zerospeech/libri-light-abx2).
|
|
64
|
+
The important change is that now the DTW is computed in a PyTorch C++ extension,
|
|
65
|
+
with CPU (using OpenMP) and CUDA backends. The speedup is most noticeable on
|
|
66
|
+
large cells, such as those obtained when running the Phoneme ABX without
|
|
67
|
+
context conditions.
|
|
68
|
+
|
|
69
|
+
## Installation
|
|
70
|
+
|
|
71
|
+
### Pre-built package
|
|
72
|
+
|
|
73
|
+
To use the package on macOS or Linux x86-64, install it with `pip` in an environment with
|
|
74
|
+
Python >=3.12:
|
|
75
|
+
|
|
76
|
+
```bash
|
|
77
|
+
pip install fastabx
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
We will add support for other platforms soon.
|
|
81
|
+
|
|
82
|
+
### Build from source
|
|
83
|
+
|
|
84
|
+
Download this repository and run:
|
|
85
|
+
|
|
86
|
+
```bash
|
|
87
|
+
CXX="g++" uv build --wheel
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
If you want CUDA support, you must have the CUDA toolkit (including `nvcc`) installed,
|
|
91
|
+
and the `CUDA_HOME` environment variable set.
|
|
92
|
+
|
|
93
|
+
## Examples
|
|
94
|
+
|
|
95
|
+
### ZeroSpeech ABX task
|
|
96
|
+
|
|
97
|
+
#### Python
|
|
98
|
+
|
|
99
|
+
```py
|
|
100
|
+
import torch
|
|
101
|
+
|
|
102
|
+
from fastabx import Dataset, Score, Subsampler, Task
|
|
103
|
+
|
|
104
|
+
dataset = Dataset.from_item("./triphone-dev-clean.item", "./features/dev-clean", 50, torch.load)
|
|
105
|
+
task = Task(dataset, on="#phone", by=["next-phone", "prev-phone", "speaker"], subsampler=Subsampler())
|
|
106
|
+
score = Score(task, "cosine")
|
|
107
|
+
print(score.collapse(levels=[("next-phone", "prev-phone"), "speaker"]))
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
#### CLI
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
❯ fastabx --help
|
|
114
|
+
usage: fastabx [-h] [--frequency FREQUENCY] [--speaker {within,across}] [--context {within,any}]
|
|
115
|
+
[--distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}]
|
|
116
|
+
[--max-size-group MAX_SIZE_GROUP] [--max-x-across MAX_X_ACROSS]
|
|
117
|
+
[--seed SEED]
|
|
118
|
+
item features
|
|
119
|
+
|
|
120
|
+
ZeroSpeech ABX
|
|
121
|
+
|
|
122
|
+
positional arguments:
|
|
123
|
+
item Path to the item file
|
|
124
|
+
features Path to the features directory
|
|
125
|
+
|
|
126
|
+
options:
|
|
127
|
+
-h, --help show this help message and exit
|
|
128
|
+
--frequency FREQUENCY
|
|
129
|
+
Feature frequency (in Hz)
|
|
130
|
+
--speaker {within,across}
|
|
131
|
+
--context {within,any}
|
|
132
|
+
--distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}
|
|
133
|
+
--max-size-group MAX_SIZE_GROUP
|
|
134
|
+
Maximum size of a cell
|
|
135
|
+
--max-x-across MAX_X_ACROSS
|
|
136
|
+
With 'across', maximum number of X given (A, B)
|
|
137
|
+
--seed SEED
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
### ABX between two gaussians
|
|
141
|
+
|
|
142
|
+
```python
|
|
143
|
+
import matplotlib.pyplot as plt
|
|
144
|
+
import numpy as np
|
|
145
|
+
|
|
146
|
+
from fastabx import Dataset, Score, Task
|
|
147
|
+
|
|
148
|
+
n = 100
|
|
149
|
+
diagonal_shift = 4
|
|
150
|
+
mean = np.zeros(2)
|
|
151
|
+
cov = np.array([[4, -2], [-2, 3]])
|
|
152
|
+
|
|
153
|
+
rng = np.random.default_rng(seed=0)
|
|
154
|
+
first = rng.multivariate_normal(mean, cov, n)
|
|
155
|
+
second = rng.multivariate_normal(mean + np.ones(2) * diagonal_shift, cov, n)
|
|
156
|
+
|
|
157
|
+
dataset = Dataset.from_numpy(np.vstack([first, second]), {"label": [0] * n + [1] * n})
|
|
158
|
+
task = Task(dataset, on="label")
|
|
159
|
+
score = Score(task, "euclidean")
|
|
160
|
+
|
|
161
|
+
plt.plot(*first.T, ".", alpha=0.5)
|
|
162
|
+
plt.plot(*second.T, ".", alpha=0.5)
|
|
163
|
+
plt.axis("equal")
|
|
164
|
+
plt.grid()
|
|
165
|
+
plt.title(f"ABX: {score.collapse():.3%}")
|
|
166
|
+
plt.show()
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
<img src="./assets/gaussian.png" width=70%>
|
|
170
|
+
|
|
171
|
+
#### ABX with increasing shift between the two gaussians
|
|
172
|
+
|
|
173
|
+

|
fastabx-0.1.0/README.md
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# Fast ABX
|
|
2
|
+
|
|
3
|
+
## Motivation
|
|
4
|
+
|
|
5
|
+
1. Simple and generic API
|
|
6
|
+
2. As fast as possible
|
|
7
|
+
|
|
8
|
+
This library aims to be as clear and minimal as possible to make its maintenance easy,
|
|
9
|
+
and the code readable and quick to understand. It should be easy to incorporate
|
|
10
|
+
different components into one's personal code, and not just use it as a black box.
|
|
11
|
+
|
|
12
|
+
At the same time, it must be as fast as possible to calculate the ABX, both in
|
|
13
|
+
forming triplets and calculating the distances themselves, while offering the
|
|
14
|
+
possibility to use any configuration of "on," "by," and "across" conditions.
|
|
15
|
+
|
|
16
|
+
The idea of creating yet another ABX library comes from the realization
|
|
17
|
+
that the [polars](https://github.com/pola-rs/polars) library efficiently and easily
|
|
18
|
+
solves the difficulties associated with creating triplets.
|
|
19
|
+
|
|
20
|
+
We can write the creation of the triplets as some "join" and "select" operations
|
|
21
|
+
on dataframes, then some "filter" for subsampling. With `polars`, the full query
|
|
22
|
+
is built lazily and then processed end-to-end. The backend will run several
|
|
23
|
+
optimizations for us, and can even run on GPU. We don't have to worry anymore
|
|
24
|
+
about how to build the triplets in a clever manner.
|
|
25
|
+
|
|
26
|
+
The computation of the distances is similar to
|
|
27
|
+
[libri-light-abx2](https://github.com/zerospeech/libri-light-abx2).
|
|
28
|
+
The important change is that now the DTW is computed in a PyTorch C++ extension,
|
|
29
|
+
with CPU (using OpenMP) and CUDA backends. The speedup is most noticeable on
|
|
30
|
+
large cells, such as those obtained when running the Phoneme ABX without
|
|
31
|
+
context conditions.
|
|
32
|
+
|
|
33
|
+
## Installation
|
|
34
|
+
|
|
35
|
+
### Pre-built package
|
|
36
|
+
|
|
37
|
+
To use the package on macOS or Linux x86-64, install it with `pip` in an environment with
|
|
38
|
+
Python >=3.12:
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
pip install fastabx
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
We will add support for other platforms soon.
|
|
45
|
+
|
|
46
|
+
### Build from source
|
|
47
|
+
|
|
48
|
+
Download this repository and run:
|
|
49
|
+
|
|
50
|
+
```bash
|
|
51
|
+
CXX="g++" uv build --wheel
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
If you want CUDA support, you must have the CUDA toolkit (including `nvcc`) installed,
|
|
55
|
+
and the `CUDA_HOME` environment variable set.
|
|
56
|
+
|
|
57
|
+
## Examples
|
|
58
|
+
|
|
59
|
+
### ZeroSpeech ABX task
|
|
60
|
+
|
|
61
|
+
#### Python
|
|
62
|
+
|
|
63
|
+
```py
|
|
64
|
+
import torch
|
|
65
|
+
|
|
66
|
+
from fastabx import Dataset, Score, Subsampler, Task
|
|
67
|
+
|
|
68
|
+
dataset = Dataset.from_item("./triphone-dev-clean.item", "./features/dev-clean", 50, torch.load)
|
|
69
|
+
task = Task(dataset, on="#phone", by=["next-phone", "prev-phone", "speaker"], subsampler=Subsampler())
|
|
70
|
+
score = Score(task, "cosine")
|
|
71
|
+
print(score.collapse(levels=[("next-phone", "prev-phone"), "speaker"]))
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
#### CLI
|
|
75
|
+
|
|
76
|
+
```bash
|
|
77
|
+
❯ fastabx --help
|
|
78
|
+
usage: fastabx [-h] [--frequency FREQUENCY] [--speaker {within,across}] [--context {within,any}]
|
|
79
|
+
[--distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}]
|
|
80
|
+
[--max-size-group MAX_SIZE_GROUP] [--max-x-across MAX_X_ACROSS]
|
|
81
|
+
[--seed SEED]
|
|
82
|
+
item features
|
|
83
|
+
|
|
84
|
+
ZeroSpeech ABX
|
|
85
|
+
|
|
86
|
+
positional arguments:
|
|
87
|
+
item Path to the item file
|
|
88
|
+
features Path to the features directory
|
|
89
|
+
|
|
90
|
+
options:
|
|
91
|
+
-h, --help show this help message and exit
|
|
92
|
+
--frequency FREQUENCY
|
|
93
|
+
Feature frequency (in Hz)
|
|
94
|
+
--speaker {within,across}
|
|
95
|
+
--context {within,any}
|
|
96
|
+
--distance {euclidean,cosine,angular,kl,kl_symmetric,identical,null}
|
|
97
|
+
--max-size-group MAX_SIZE_GROUP
|
|
98
|
+
Maximum size of a cell
|
|
99
|
+
--max-x-across MAX_X_ACROSS
|
|
100
|
+
With 'across', maximum number of X given (A, B)
|
|
101
|
+
--seed SEED
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### ABX between two gaussians
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
import matplotlib.pyplot as plt
|
|
108
|
+
import numpy as np
|
|
109
|
+
|
|
110
|
+
from fastabx import Dataset, Score, Task
|
|
111
|
+
|
|
112
|
+
n = 100
|
|
113
|
+
diagonal_shift = 4
|
|
114
|
+
mean = np.zeros(2)
|
|
115
|
+
cov = np.array([[4, -2], [-2, 3]])
|
|
116
|
+
|
|
117
|
+
rng = np.random.default_rng(seed=0)
|
|
118
|
+
first = rng.multivariate_normal(mean, cov, n)
|
|
119
|
+
second = rng.multivariate_normal(mean + np.ones(2) * diagonal_shift, cov, n)
|
|
120
|
+
|
|
121
|
+
dataset = Dataset.from_numpy(np.vstack([first, second]), {"label": [0] * n + [1] * n})
|
|
122
|
+
task = Task(dataset, on="label")
|
|
123
|
+
score = Score(task, "euclidean")
|
|
124
|
+
|
|
125
|
+
plt.plot(*first.T, ".", alpha=0.5)
|
|
126
|
+
plt.plot(*second.T, ".", alpha=0.5)
|
|
127
|
+
plt.axis("equal")
|
|
128
|
+
plt.grid()
|
|
129
|
+
plt.title(f"ABX: {score.collapse():.3%}")
|
|
130
|
+
plt.show()
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
<img src="./assets/gaussian.png" width=70%>
|
|
134
|
+
|
|
135
|
+
#### ABX with increasing shift between the two gaussians
|
|
136
|
+
|
|
137
|
+

|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = [
|
|
3
|
+
"setuptools>=75.0.0",
|
|
4
|
+
"setuptools-scm>=8.1.0",
|
|
5
|
+
"torch>=2.6.0",
|
|
6
|
+
"numpy>=2.0.2",
|
|
7
|
+
"ninja>=1.11",
|
|
8
|
+
]
|
|
9
|
+
build-backend = "setuptools.build_meta"
|
|
10
|
+
|
|
11
|
+
[project]
|
|
12
|
+
name = "fastabx"
|
|
13
|
+
version = "0.1.0"
|
|
14
|
+
description = "Fast ABX"
|
|
15
|
+
readme = "README.md"
|
|
16
|
+
requires-python = ">=3.12"
|
|
17
|
+
authors = [ {name = "Maxime Poli"} ]
|
|
18
|
+
license = { file = "LICENSE" } # To update to PEP 639 once support is added to setuptools
|
|
19
|
+
dependencies = [
|
|
20
|
+
"numpy>=2.1.3",
|
|
21
|
+
"polars>=1.14.0",
|
|
22
|
+
"torch>=2.6.0",
|
|
23
|
+
"tqdm>=4.67.1",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[project.optional-dependencies]
|
|
27
|
+
gpu = [
|
|
28
|
+
"polars[gpu]>=1.14.0",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[project.scripts]
|
|
32
|
+
fastabx = "fastabx.__main__:main"
|
|
33
|
+
|
|
34
|
+
[dependency-groups]
|
|
35
|
+
dev = [
|
|
36
|
+
"ipdb>=0.13.13",
|
|
37
|
+
"ipykernel>=6.29.5",
|
|
38
|
+
"ipywidgets>=8.1.5",
|
|
39
|
+
"matplotlib>=3.10.0",
|
|
40
|
+
"mypy>=1.14.1",
|
|
41
|
+
"pre-commit>=4.0.1",
|
|
42
|
+
"ruff>=0.9.1",
|
|
43
|
+
"types-tqdm>=4.67.0.20241221",
|
|
44
|
+
"typos>=1.29.4",
|
|
45
|
+
]
|
|
46
|
+
test = [
|
|
47
|
+
"hypothesis>=6.125.1",
|
|
48
|
+
"pytest>=8.3.4",
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
[tool.ruff]
|
|
52
|
+
line-length = 119
|
|
53
|
+
|
|
54
|
+
[tool.ruff.lint]
|
|
55
|
+
select = ["ALL"]
|
|
56
|
+
ignore = [
|
|
57
|
+
"COM812", # missing-trailing-comma
|
|
58
|
+
"D105", # undocumented-magic-method
|
|
59
|
+
"D107", # undocumented-public-init
|
|
60
|
+
"EM101", # raw-string-in-exception
|
|
61
|
+
"PD901", # pandas-df-variable-name
|
|
62
|
+
"TRY003", # raise-vanilla-args
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
[tool.setuptools]
|
|
66
|
+
license-files = [] # https://github.com/pypa/setuptools/issues/4759
|
fastabx-0.1.0/setup.cfg
ADDED
fastabx-0.1.0/setup.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Build the DTW PyTorch C++ extension."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
from setuptools import Extension, setup
|
|
7
|
+
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension, CUDAExtension
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_extension() -> "Extension":
    """Describe the ``fastabx._C`` PyTorch extension to compile.

    When PyTorch locates a CUDA installation (``CUDA_HOME`` is not ``None``), a CUDA
    extension is built from the C++ source plus the CUDA kernel. Otherwise only the
    C++ source is compiled as a plain CPU extension.

    Returns:
        The setuptools extension object to pass to ``setup(ext_modules=...)``.

    """
    use_cuda = CUDA_HOME is not None
    extension = CUDAExtension if use_cuda else CppExtension
    # OpenMP is only requested on Linux (presumably because the default macOS
    # toolchain does not support a bare ``-fopenmp``; confirm before widening).
    openmp = ["-fopenmp"] if sys.platform == "linux" else []
    extra_compile_args = {
        # 0x030c0000 is CPython 3.12: compile against the limited API so the result
        # matches the ``cp312`` limited-API wheel tag requested in ``setup()``.
        "cxx": ["-fdiagnostics-color=always", "-DPy_LIMITED_API=0x030c0000", "-O3", *openmp],
        "nvcc": ["-O3"],
    }
    sources = ["src/fastabx/csrc/dtw.cpp"]
    if use_cuda:
        # Default GPU architectures to compile for. An explicit TORCH_CUDA_ARCH_LIST
        # already present in the build environment takes precedence.
        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "Volta;Turing;Ampere;Ada;Hopper")
        sources.append("src/fastabx/csrc/cuda/dtw.cu")
    return extension(
        "fastabx._C",
        sources,
        extra_compile_args=extra_compile_args,
        extra_link_args=openmp,
        py_limited_api=True,
    )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Build the single compiled module through PyTorch's BuildExtension command, and tag
# the wheel for the CPython 3.12 limited API ("cp312"), the same ABI level the
# extension is compiled against in get_extension().
setup(
    cmdclass={"build_ext": BuildExtension},
    ext_modules=[get_extension()],
    options={
        "bdist_wheel": {
            "py_limited_api": "cp312",
        },
    },
)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Full ABX."""
|
|
2
|
+
|
|
3
|
+
from fastabx.dataset import Dataset
|
|
4
|
+
from fastabx.pooling import pooling
|
|
5
|
+
from fastabx.score import Score
|
|
6
|
+
from fastabx.subsample import Subsampler
|
|
7
|
+
from fastabx.task import Task
|
|
8
|
+
from fastabx.zerospeech import zerospeech_abx
|
|
9
|
+
|
|
10
|
+
from . import _C # type: ignore[attr-defined] # Load the PyTorch C++ extension.
|
|
11
|
+
|
|
12
|
+
__all__ = ["_C", "Dataset", "Score", "Subsampler", "Task", "pooling", "zerospeech_abx"]
|