gmtorch 0.0.0.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ _commit: v1.6.0
2
+ _src_path: gh:superlinear-ai/substrate
3
+ author_email: tomass.timmermans@student.kit.edu
4
+ author_name: Tomass Marks Timmermans
5
+ project_description: Probabilistic Generative AI Using Gaussian Mixture
6
+ project_name: gmtorch
7
+ project_type: package
8
+ project_url: https://gitlab.kit.edu/ali.darijani/gmtorch
9
+ python_version: "3.12"
10
+ typing: strict
11
+ with_conventional_commits: true
12
+ with_typer_cli: false
@@ -0,0 +1,80 @@
1
+ {
2
+ "name": "gmtorch",
3
+ "dockerComposeFile": "../docker-compose.yml",
4
+ "service": "devcontainer",
5
+ "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}/",
6
+ "features": {
7
+ "ghcr.io/devcontainers-extra/features/starship:1": {}
8
+ },
9
+ "overrideCommand": true,
10
+ "remoteUser": "user",
11
+ "postStartCommand": "sudo chown -R user:user /opt/ && uv sync --python ${localEnv:PYTHON_VERSION:3.12} --resolution ${localEnv:RESOLUTION_STRATEGY:highest} --all-extras && pre-commit install --install-hooks",
12
+ "customizations": {
13
+ "jetbrains": {
14
+ "backend": "PyCharm",
15
+ "plugins": [
16
+ "com.github.copilot"
17
+ ]
18
+ },
19
+ "vscode": {
20
+ "extensions": [
21
+ "charliermarsh.ruff",
22
+ "GitHub.copilot",
23
+ "GitHub.copilot-chat",
24
+ "GitLab.gitlab-workflow",
25
+ "ms-azuretools.vscode-docker",
26
+ "ms-python.mypy-type-checker",
27
+ "ms-python.python",
28
+ "ms-toolsai.jupyter",
29
+ "ryanluker.vscode-coverage-gutters",
30
+ "tamasfe.even-better-toml",
31
+ "visualstudioexptteam.vscodeintellicode"
32
+ ],
33
+ "settings": {
34
+ "coverage-gutters.coverageFileNames": [
35
+ "reports/coverage.xml"
36
+ ],
37
+ "editor.codeActionsOnSave": {
38
+ "source.fixAll": "explicit",
39
+ "source.organizeImports": "explicit"
40
+ },
41
+ "editor.formatOnSave": true,
42
+ "[python]": {
43
+ "editor.defaultFormatter": "charliermarsh.ruff"
44
+ },
45
+ "[toml]": {
46
+ "editor.formatOnSave": false
47
+ },
48
+ "editor.rulers": [
49
+ 100
50
+ ],
51
+ "files.autoSave": "onFocusChange",
52
+ "github.copilot.chat.agent.enabled": true,
53
+ "github.copilot.chat.codesearch.enabled": true,
54
+ "github.copilot.chat.edits.enabled": true,
55
+ "github.copilot.nextEditSuggestions.enabled": true,
56
+ "jupyter.kernels.excludePythonEnvironments": [
57
+ "/usr/local/bin/python"
58
+ ],
59
+ "mypy-type-checker.importStrategy": "fromEnvironment",
60
+ "mypy-type-checker.preferDaemon": true,
61
+ "notebook.codeActionsOnSave": {
62
+ "notebook.source.fixAll": "explicit",
63
+ "notebook.source.organizeImports": "explicit"
64
+ },
65
+ "notebook.formatOnSave.enabled": true,
66
+ "python.defaultInterpreterPath": "/opt/venv/bin/python",
67
+ "python.terminal.activateEnvironment": false,
68
+ "python.testing.pytestEnabled": true,
69
+ "ruff.importStrategy": "fromEnvironment",
70
+ "ruff.logLevel": "warning",
71
+ "terminal.integrated.env.linux": {
72
+ "GIT_EDITOR": "code --wait"
73
+ },
74
+ "terminal.integrated.env.mac": {
75
+ "GIT_EDITOR": "code --wait"
76
+ }
77
+ }
78
+ }
79
+ }
80
+ }
@@ -0,0 +1,8 @@
1
+ # Caches
2
+ .*_cache/
3
+
4
+ # Git
5
+ .git/
6
+
7
+ # Python
8
+ .venv/
@@ -0,0 +1,72 @@
1
+ # Coverage.py
2
+ htmlcov/
3
+ reports/
4
+
5
+ # Copier
6
+ *.rej
7
+
8
+ # Data
9
+ *.csv*
10
+ *.dat*
11
+ *.pickle*
12
+ *.xls*
13
+ *.zip*
14
+ data/
15
+
16
+ # direnv
17
+ .envrc
18
+
19
+ # dotenv
20
+ .env
21
+
22
+ # Hypothesis
23
+ .hypothesis/
24
+
25
+ # Jupyter
26
+ *.ipynb
27
+ .ipynb_checkpoints/
28
+ notebooks/
29
+
30
+ # macOS
31
+ .DS_Store
32
+
33
+ # mise
34
+ mise.local.toml
35
+
36
+ # MkDocs
37
+ site/
38
+
39
+ # mypy
40
+ .dmypy.json
41
+ .mypy_cache/
42
+
43
+ # Node.js
44
+ node_modules/
45
+
46
+ # PyCharm
47
+ .idea/
48
+
49
+ # pyenv
50
+ .python-version
51
+
52
+ # pytest
53
+ .pytest_cache/
54
+
55
+ # Python
56
+ __pycache__/
57
+ *.egg-info/
58
+ *.py[cdo]
59
+ .venv/
60
+ dist/
61
+
62
+ # Ruff
63
+ .ruff_cache/
64
+
65
+ # Terraform
66
+ .terraform/
67
+
68
+ # uv
69
+ uv.lock
70
+
71
+ # VS Code
72
+ .vscode/
@@ -0,0 +1,338 @@
1
+ # You can override the included template(s) by including variable overrides
2
+ # SAST customization: https://docs.gitlab.com/ee/user/application_security/sast/#customizing-the-sast-settings
3
+ # Secret Detection customization: https://docs.gitlab.com/user/application_security/secret_detection/pipeline/configure
4
+ # Dependency Scanning customization: https://docs.gitlab.com/ee/user/application_security/dependency_scanning/#customizing-the-dependency-scanning-settings
5
+ # Container Scanning customization: https://docs.gitlab.com/ee/user/application_security/container_scanning/#customizing-the-container-scanning-settings
6
+ # Note that environment variables can be set in several places
7
+ # See https://docs.gitlab.com/ee/ci/variables/#cicd-variable-precedence
8
+ stages:
9
+ - secret-detection
10
+ - build
11
+ - test
12
+ - publish
13
+
14
+ include:
15
+ - template: Security/SAST.gitlab-ci.yml
16
+ - template: Security/Secret-Detection.gitlab-ci.yml
17
+
18
+ variables:
19
+ SECRET_DETECTION_ENABLED: "true"
20
+ GIT_STRATEGY: clone
21
+ GIT_CLEAN_FLAGS: -ffdx
22
+
23
+ # Set the tag to the runner you want to use (also works for individual jobs)
24
+ # KIT runners: https://docs.gitlab.kit.edu/en/gitlab_runner/
25
+ default:
26
+ tags:
27
+ - local
28
+
29
+ # SECRET DETECTION
30
+ secret_detection:
31
+ stage: secret-detection
32
+
33
+ # BUILD
34
+ build_image:
35
+ stage: build
36
+ image:
37
+ name: gcr.io/kaniko-project/executor:debug
38
+ entrypoint: [""]
39
+ variables:
40
+ DOCKER_CONFIG: "/kaniko/.docker"
41
+ IMAGE_TAG: "$DOCKERHUB_REPO:$CI_COMMIT_SHORT_SHA"
42
+ DOCKERFILE_PATH: "$CI_PROJECT_DIR/Dockerfile"
43
+ CONTEXT: "$CI_PROJECT_DIR"
44
+ CACHE_REPO: "$DOCKERHUB_REPO"
45
+ before_script:
46
+ - mkdir -p /kaniko/.docker
47
+ - |
48
+ if [ -z "${DOCKERHUB_USERNAME}" ] || [ -z "${DOCKERHUB_PASSWORD}" ]; then
49
+ echo "DOCKERHUB_USERNAME or DOCKERHUB_PASSWORD is not set. Please define them in CI/CD variables." >&2
50
+ exit 1
51
+ fi
52
+ cat > /kaniko/.docker/config.json <<EOF
53
+ {
54
+ "auths": {
55
+ "https://index.docker.io/v1/": {
56
+ "auth": "$(echo -n ${DOCKERHUB_USERNAME}:${DOCKERHUB_PASSWORD} | base64)"
57
+ }
58
+ }
59
+ }
60
+ EOF
61
+ script:
62
+ - |
63
+ echo "Building and pushing image: $IMAGE_TAG via Kaniko"
64
+ /kaniko/executor \
65
+ --context "$CONTEXT" \
66
+ --dockerfile "$DOCKERFILE_PATH" \
67
+ --destination "$IMAGE_TAG" \
68
+ --target base \
69
+ --cache=true \
70
+ --cache-repo "$CACHE_REPO" \
71
+ --snapshot-mode=redo \
72
+ --build-arg BUILDKIT_INLINE_CACHE=1 \
73
+ --skip-unused-stages \
74
+ --compressed-caching=false
75
+
76
+ build_project:
77
+ stage: build
78
+ image: python:3.12-alpine
79
+ before_script:
80
+ - apk add --no-cache git curl jq
81
+ - pip install build
82
+ script:
83
+ - |
84
+ set -e # Exit on error
85
+
86
+ BASE_VERSION=$(grep '^version = ' pyproject.toml | head -1 | cut -d'"' -f2)
87
+ PACKAGE_NAME=$(grep '^name = ' pyproject.toml | head -1 | cut -d'"' -f2)
88
+ BUILD_VERSION="$BASE_VERSION"
89
+
90
+ echo "Fetching existing versions from PyPI..."
91
+ PYPI_RESPONSE=$(curl -s "https://pypi.org/pypi/$PACKAGE_NAME/json" 2>/dev/null || echo '{"releases":{}}')
92
+
93
+ if echo "$PYPI_RESPONSE" | jq -e '.releases' > /dev/null 2>&1; then
94
+ PACKAGE_EXISTS=true
95
+ echo "✅ Package found on PyPI"
96
+ else
97
+ PACKAGE_EXISTS=false
98
+ echo "ℹ️ Package not yet published on PyPI"
99
+ fi
100
+ echo ""
101
+
102
+
103
+ if [ "$CI_COMMIT_BRANCH" = "main" ] || [ "$CI_MERGE_REQUEST_TARGET_BRANCH_NAME" = "main" ]; then
104
+ echo "Building stable version: $BUILD_VERSION"
105
+
106
+ if [ "$PACKAGE_EXISTS" = "true" ] && echo "$PYPI_RESPONSE" | jq -e ".releases.\"$BUILD_VERSION\"" > /dev/null 2>&1; then
107
+ echo "❌ ERROR: Version $BUILD_VERSION already exists on PyPI!"
108
+ echo ""
109
+ echo "Please bump the version before merging to main:"
110
+ echo "--> cz bump"
111
+ exit 1
112
+ fi
113
+
114
+ elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
115
+ echo "Checking existing dev versions on PyPI..."
116
+
117
+ if [ "$PACKAGE_EXISTS" = "true" ]; then
118
+ EXISTING_VERSIONS=$(echo "$PYPI_RESPONSE" | jq -r '.releases | keys[]' | grep "^${BASE_VERSION}\.dev" || echo "")
119
+ else
120
+ EXISTING_VERSIONS=""
121
+ fi
122
+
123
+ if [ -z "$EXISTING_VERSIONS" ]; then
124
+ DEV_NUMBER=0
125
+ echo "No existing dev versions found, starting with dev0"
126
+ else
127
+ echo "Existing dev versions:"
128
+ echo "$EXISTING_VERSIONS"
129
+
130
+ HIGHEST_DEV=$(echo "$EXISTING_VERSIONS" | sed "s/${BASE_VERSION}\.dev//" | sort -n | tail -1)
131
+ DEV_NUMBER=$((HIGHEST_DEV + 1))
132
+ echo "Highest existing: dev${HIGHEST_DEV}"
133
+ echo "Using: dev${DEV_NUMBER}"
134
+ fi
135
+
136
+ BUILD_VERSION="${BASE_VERSION}.dev${DEV_NUMBER}"
137
+ echo "Building development version: $BUILD_VERSION"
138
+
139
+ else
140
+ SHORT_SHA=$(git rev-parse --short=8 HEAD)
141
+ BUILD_VERSION="${BASE_VERSION}.dev0+${SHORT_SHA}"
142
+ echo "Building feature branch version: $BUILD_VERSION"
143
+ echo "⚠️ Note: This version should NOT be published"
144
+ fi
145
+
146
+ sed -i "s/^version = .*/version = \"$BUILD_VERSION\"/" pyproject.toml
147
+
148
+ echo "Updated pyproject.toml version:"
149
+ grep '^version = ' pyproject.toml
150
+ echo ""
151
+
152
+ echo "Building package..."
153
+ python -m build --sdist --wheel --outdir dist
154
+
155
+ echo "BUILD_VERSION=${BUILD_VERSION}" >> build.env
156
+ echo "PACKAGE_NAME=${PACKAGE_NAME}" >> build.env
157
+
158
+ echo ""
159
+ echo "✅ Successfully built version: $BUILD_VERSION"
160
+ echo ""
161
+
162
+ artifacts:
163
+ when: on_success
164
+ expire_in: 1 week
165
+ reports:
166
+ dotenv: build.env
167
+ paths:
168
+ - dist/*.whl
169
+ - dist/*.tar.gz
170
+
171
+ # TEST
172
+ sast:
173
+ stage: test
174
+
175
+ # All tests but cuda tests
176
+ unit_tests:
177
+ stage: test
178
+ image: python:3.12-slim
179
+ needs: ["build_project"]
180
+ before_script:
181
+ - pip install pytest pytest-xdist
182
+ - pip install dist/*.whl
183
+ script:
184
+ - pytest tests/ --ignore=tests/cuda/ -n auto --capture=tee-sys --full-trace -v -ra --junitxml=junit.xml
185
+ artifacts:
186
+ when: on_failure
187
+ expire_in: 1 week
188
+ reports:
189
+ junit: junit.xml
190
+ paths:
191
+ - junit.xml
192
+
193
+ # Tests run with GPU / CUDA
194
+ gpu_tests:
195
+ stage: test
196
+ image: $DOCKERHUB_REPO:$CI_COMMIT_SHORT_SHA
197
+ needs: ["build_image"]
198
+ # Install pytest-xdist and add -n auto for parallel execution if needed
199
+ before_script:
200
+ - uv pip install pytest
201
+ script:
202
+ - pytest tests/cuda/ -v --junitxml=junit-gpu.xml
203
+ artifacts:
204
+ when: on_failure
205
+ expire_in: 1 week
206
+ reports:
207
+ junit: junit-gpu.xml
208
+ paths:
209
+ - junit-gpu.xml
210
+
211
+ # PUBLISH
212
+ publish_artifacts:
213
+ stage: publish
214
+ image: python:3.12-alpine
215
+ needs:
216
+ - job: build_project
217
+ artifacts: true
218
+ - job: unit_tests
219
+ - job: gpu_tests
220
+ when: manual
221
+ before_script:
222
+ - pip install twine
223
+ script:
224
+ - |
225
+ echo "Publishing version $BUILD_VERSION to GitLab Package Registry"
226
+ echo "Package: $PACKAGE_NAME"
227
+ TWINE_PASSWORD=${CI_JOB_TOKEN} TWINE_USERNAME=gitlab-ci-token python -m twine upload --repository-url ${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/pypi dist/*.whl dist/*.tar.gz
228
+ rules:
229
+ - if: $CI_COMMIT_BRANCH == "main" || $CI_COMMIT_BRANCH == "develop"
230
+
231
+ publish_test_pypi:
232
+ stage: publish
233
+ image: python:3.12-alpine
234
+ needs:
235
+ - job: build_project
236
+ artifacts: true
237
+ - job: unit_tests
238
+ - job: gpu_tests
239
+ when: manual
240
+ before_script:
241
+ - pip install twine
242
+ script:
243
+ - |
244
+ echo "Publishing version $BUILD_VERSION to Test PyPI"
245
+ python -m twine upload --repository testpypi dist/*.whl dist/*.tar.gz
246
+ variables:
247
+ TWINE_USERNAME: __token__
248
+ TWINE_PASSWORD: ${TEST_PYPI_API_TOKEN}
249
+ TWINE_REPOSITORY_URL: https://test.pypi.org/legacy/
250
+ rules:
251
+ - if: $CI_COMMIT_BRANCH == "main" || $CI_COMMIT_BRANCH == "develop"
252
+
253
+ publish_pypi:
254
+ stage: publish
255
+ image: python:3.12-alpine
256
+ needs:
257
+ - job: build_project
258
+ artifacts: true
259
+ - job: unit_tests
260
+ - job: gpu_tests
261
+ when: manual
262
+ before_script:
263
+ - pip install twine
264
+ script:
265
+ - |
266
+ echo "Publishing version $BUILD_VERSION to PyPI"
267
+ python -m twine upload dist/*.whl dist/*.tar.gz
268
+ variables:
269
+ TWINE_USERNAME: __token__
270
+ TWINE_PASSWORD: ${PYPI_API_TOKEN}
271
+ rules:
272
+ - if: $CI_COMMIT_BRANCH == "main" || $CI_COMMIT_BRANCH == "develop"
273
+
274
+ publish_docker:
275
+ stage: publish
276
+ image:
277
+ name: gcr.io/kaniko-project/executor:debug
278
+ entrypoint: [""]
279
+ needs:
280
+ - job: build_project
281
+ artifacts: true
282
+ - job: build_image
283
+ - job: unit_tests
284
+ - job: gpu_tests
285
+ when: manual
286
+ variables:
287
+ DOCKER_CONFIG: "/kaniko/.docker"
288
+ DOCKERFILE_PATH: "$CI_PROJECT_DIR/Dockerfile"
289
+ CONTEXT: "$CI_PROJECT_DIR"
290
+ CACHE_REPO: "$DOCKERHUB_REPO"
291
+ before_script:
292
+ - mkdir -p /kaniko/.docker
293
+ - |
294
+ if [ -z "${DOCKERHUB_USERNAME}" ] || [ -z "${DOCKERHUB_PASSWORD}" ]; then
295
+ echo "DOCKERHUB_USERNAME or DOCKERHUB_PASSWORD is not set. Please define them in CI/CD variables." >&2
296
+ exit 1
297
+ fi
298
+ cat > /kaniko/.docker/config.json <<EOF
299
+ {
300
+ "auths": {
301
+ "https://index.docker.io/v1/": {
302
+ "auth": "$(echo -n ${DOCKERHUB_USERNAME}:${DOCKERHUB_PASSWORD} | base64)"
303
+ }
304
+ }
305
+ }
306
+ EOF
307
+ script:
308
+ - |
309
+ # Determine tags based on branch
310
+ if [ "$CI_COMMIT_BRANCH" = "main" ]; then
311
+ # Get version from pyproject.toml (stable version)
312
+ VERSION=$(grep '^version = ' pyproject.toml | cut -d'"' -f2)
313
+ echo "Publishing stable Docker image: stable"
314
+ /kaniko/executor \
315
+ --context "$CONTEXT" \
316
+ --dockerfile "$DOCKERFILE_PATH" \
317
+ --destination "$DOCKERHUB_REPO:stable" \
318
+ --cache=true \
319
+ --cache-repo "$CACHE_REPO" \
320
+ --compressed-caching=false \
321
+ --build-arg BUILDKIT_INLINE_CACHE=1 \
322
+ --skip-unused-stages \
323
+ --single-snapshot
324
+ elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
325
+ echo "Publishing development Docker image: develop"
326
+ /kaniko/executor \
327
+ --context "$CONTEXT" \
328
+ --dockerfile "$DOCKERFILE_PATH" \
329
+ --destination "$DOCKERHUB_REPO:develop" \
330
+ --cache=true \
331
+ --cache-repo "$CACHE_REPO" \
332
+ --compressed-caching=false \
333
+ --build-arg BUILDKIT_INLINE_CACHE=1 \
334
+ --skip-unused-stages \
335
+ --single-snapshot
336
+ fi
337
+ rules:
338
+ - if: $CI_COMMIT_BRANCH == "main" || $CI_COMMIT_BRANCH == "develop"
@@ -0,0 +1,73 @@
1
+ # https://pre-commit.com
2
+ default_install_hook_types: [commit-msg, pre-commit]
3
+ default_stages: [pre-commit, manual]
4
+ fail_fast: true
5
+ repos:
6
+ - repo: meta
7
+ hooks:
8
+ - id: check-useless-excludes
9
+ - repo: https://github.com/pre-commit/pygrep-hooks
10
+ rev: v1.10.0
11
+ hooks:
12
+ - id: python-check-mock-methods
13
+ - id: python-use-type-annotations
14
+ - id: rst-backticks
15
+ - id: rst-directive-colons
16
+ - id: rst-inline-touching-normal
17
+ - id: text-unicode-replacement-char
18
+ - repo: https://github.com/pre-commit/pre-commit-hooks
19
+ rev: v5.0.0
20
+ hooks:
21
+ - id: check-added-large-files
22
+ - id: check-ast
23
+ - id: check-builtin-literals
24
+ - id: check-case-conflict
25
+ - id: check-docstring-first
26
+ - id: check-illegal-windows-names
27
+ - id: check-json
28
+ - id: check-merge-conflict
29
+ - id: check-shebang-scripts-are-executable
30
+ - id: check-symlinks
31
+ - id: check-toml
32
+ - id: check-vcs-permalinks
33
+ - id: check-xml
34
+ - id: check-yaml
35
+ - id: debug-statements
36
+ - id: destroyed-symlinks
37
+ - id: detect-private-key
38
+ - id: end-of-file-fixer
39
+ types: [python]
40
+ - id: fix-byte-order-marker
41
+ - id: mixed-line-ending
42
+ - id: name-tests-test
43
+ args: [--pytest-test-first]
44
+ - id: trailing-whitespace
45
+ types: [python]
46
+ - repo: local
47
+ hooks:
48
+ - id: commitizen
49
+ name: commitizen
50
+ entry: cz check
51
+ args: [--commit-msg-file]
52
+ require_serial: true
53
+ language: system
54
+ stages: [commit-msg]
55
+ - id: ruff-check
56
+ name: ruff check
57
+ entry: ruff check
58
+ args: ["--force-exclude", "--extend-fixable=ERA001,F401,F841,T201,T203"]
59
+ require_serial: true
60
+ language: system
61
+ types_or: [python, pyi]
62
+ - id: ruff-format
63
+ name: ruff format
64
+ entry: ruff format
65
+ args: [--force-exclude]
66
+ require_serial: true
67
+ language: system
68
+ types_or: [python, pyi]
69
+ - id: mypy
70
+ name: mypy
71
+ entry: mypy
72
+ language: system
73
+ types: [python]
@@ -0,0 +1,56 @@
1
+ # syntax=docker/dockerfile:1
2
+
3
+ # Build stage: Create the image with dependencies
4
+ FROM nvidia/cuda:12.6.0-runtime-ubuntu24.04 AS base
5
+
6
+ ENV DEBIAN_FRONTEND=noninteractive
7
+
8
+ RUN apt-get update && apt-get install -y --no-install-recommends \
9
+ python3.12 \
10
+ python3.12-venv \
11
+ && rm -rf /var/lib/apt/lists/* \
12
+ && apt-get clean
13
+
14
+ # Install uv
15
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
16
+
17
+ # Set working directory
18
+ WORKDIR /app
19
+
20
+ # Setup virtual environment
21
+ ENV VIRTUAL_ENV=/opt/venv
22
+ ENV PATH=$VIRTUAL_ENV/bin:$PATH
23
+ ENV UV_PROJECT_ENVIRONMENT=$VIRTUAL_ENV
24
+
25
+ COPY pyproject.toml README.md ./
26
+ COPY src ./src
27
+ COPY tests ./tests
28
+
29
+ # Create venv and install dependencies
30
+ RUN uv venv "$VIRTUAL_ENV" && \
31
+ uv pip install --no-cache -e .
32
+
33
+ FROM base AS dev
34
+
35
+ # Install git, sudo, and create non-root user with passwordless sudo
36
+ RUN apt-get update && apt-get install -y --no-install-recommends \
37
+ git \
38
+ sudo \
39
+ && rm -rf /var/lib/apt/lists/* \
40
+ && apt-get clean \
41
+ && groupadd --gid 1000 user \
42
+ && useradd --create-home --no-log-init --gid 1000 --uid 1000 --shell /usr/bin/bash user \
43
+ && chown user:user /opt/venv \
44
+ && echo 'user ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/user && chmod 0440 /etc/sudoers.d/user
45
+
46
+ # Tell Git that the workspace is safe to avoid 'detected dubious ownership in repository' warnings.
47
+ RUN git config --system --add safe.directory '*'
48
+
49
+ USER user
50
+
51
+ # Configure the non-root user's shell.
52
+ RUN mkdir ~/.history/ && \
53
+ echo 'HISTFILE=~/.history/.bash_history' >> ~/.bashrc && \
54
+ echo 'bind "\e[A": history-search-backward' >> ~/.bashrc && \
55
+ echo 'bind "\e[B": history-search-forward' >> ~/.bashrc && \
56
+ echo 'eval "$(starship init bash)"' >> ~/.bashrc
@@ -0,0 +1,109 @@
1
+ Metadata-Version: 2.4
2
+ Name: gmtorch
3
+ Version: 0.0.0.dev0
4
+ Summary: Probabilistic Generative AI Using Gaussian Mixture
5
+ Project-URL: homepage, https://gitlab.kit.edu/ali.darijani/gmtorch
6
+ Project-URL: source, https://gitlab.kit.edu/ali.darijani/gmtorch
7
+ Project-URL: changelog, https://gitlab.kit.edu/ali.darijani/gmtorch/-/blob/main/CHANGELOG.md
8
+ Project-URL: releasenotes, https://gitlab.kit.edu/ali.darijani/gmtorch/-/releases
9
+ Project-URL: documentation, https://gitlab.kit.edu/ali.darijani/gmtorch
10
+ Project-URL: issues, https://gitlab.kit.edu/ali.darijani/gmtorch/-/issues
11
+ Author-email: Ali Darijani <ali.darijani@kit.edu>, Tomass Marks Timmermans <tomass.timmermans@student.kit.edu>
12
+ Requires-Python: <4.0,>=3.12
13
+ Requires-Dist: torch>=2.9.0
14
+ Requires-Dist: torchvision>=0.24.0
15
+ Description-Content-Type: text/markdown
16
+
17
+ [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=data:image/svg%2bxml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAyNCAyNCI+PHBhdGggZmlsbD0iI2ZmZiIgZD0iTTE3IDE2VjdsLTYgNU0yIDlWOGwxLTFoMWw0IDMgOC04aDFsNCAyIDEgMXYxNGwtMSAxLTQgMmgtMWwtOC04LTQgM0gzbC0xLTF2LTFsMy0zIi8+PC9zdmc+)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://gitlab.kit.edu/ali.darijani/gmtorch)
18
+
19
+ # gmtorch
20
+
21
+ Probabilistic Generative AI Using Gaussian Mixture
22
+
23
+ ## Installing
24
+
25
+ To install this package, run:
26
+
27
+ ```sh
28
+ pip install gmtorch
29
+ ```
30
+
31
+ ## Using
32
+
33
+ Example usage:
34
+
35
+ ```python
36
+ import gmtorch
37
+
38
+ ...
39
+ ```
40
+
41
+ ## Contributing
42
+
43
+ <details>
44
+ <summary>Prerequisites</summary>
45
+
46
+ 1. [Generate an SSH key](https://docs.gitlab.com/ee/user/ssh.html#generate-an-ssh-key-pair) and [add the SSH key to your GitLab account](https://docs.gitlab.com/ee/user/ssh.html#add-an-ssh-key-to-your-gitlab-account).
47
+ 1. Configure SSH to automatically load your SSH keys:
48
+
49
+ ```sh
50
+ cat << EOF >> ~/.ssh/config
51
+
52
+ Host *
53
+ AddKeysToAgent yes
54
+ IgnoreUnknown UseKeychain
55
+ UseKeychain yes
56
+ ForwardAgent yes
57
+ EOF
58
+ ```
59
+
60
+ 1. [Install Docker Desktop](https://www.docker.com/get-started).
61
+ 1. [Install VS Code](https://code.visualstudio.com/) and [VS Code's Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). Alternatively, install [PyCharm](https://www.jetbrains.com/pycharm/download/).
62
+ 1. _Optional:_ install a [Nerd Font](https://www.nerdfonts.com/font-downloads) such as [FiraCode Nerd Font](https://github.com/ryanoasis/nerd-fonts/tree/master/patched-fonts/FiraCode) and [configure VS Code](https://github.com/tonsky/FiraCode/wiki/VS-Code-Instructions) or [PyCharm](https://github.com/tonsky/FiraCode/wiki/Intellij-products-instructions) to use it.
63
+
64
+ </details>
65
+
66
+ <details open>
67
+ <summary>Development environments</summary>
68
+
69
+ The following development environments are supported:
70
+
71
+
72
+ 1. ⭐️ _VS Code Dev Container (with container volume)_: click on [Open in Dev Containers](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://gitlab.kit.edu/ali.darijani/gmtorch) to clone this repository in a container volume and create a Dev Container with VS Code.
73
+ 1. ⭐️ _uv_: clone this repository and run the following from the root of the repository:
74
+
75
+ ```sh
76
+ # Create and install a virtual environment
77
+ uv sync --python 3.12 --all-extras
78
+
79
+ # Activate the virtual environment
80
+ source .venv/bin/activate
81
+
82
+ # Install the pre-commit hooks
83
+ pre-commit install --install-hooks
84
+ ```
85
+
86
+ 1. _VS Code Dev Container_: clone this repository, open it with VS Code, and run <kbd>Ctrl/⌘</kbd> + <kbd>⇧</kbd> + <kbd>P</kbd> → _Dev Containers: Reopen in Container_.
87
+ 1. _PyCharm Dev Container_: clone this repository, open it with PyCharm, [create a Dev Container with Mount Sources](https://www.jetbrains.com/help/pycharm/start-dev-container-inside-ide.html), and [configure an existing Python interpreter](https://www.jetbrains.com/help/pycharm/configuring-python-interpreter.html#widget) at `/opt/venv/bin/python`.
88
+
89
+ </details>
90
+
91
+ <details open>
92
+ <summary>Developing</summary>
93
+
94
+ - This project follows the [Conventional Commits](https://www.conventionalcommits.org/) standard to automate [Semantic Versioning](https://semver.org/) and [Keep A Changelog](https://keepachangelog.com/) with [Commitizen](https://github.com/commitizen-tools/commitizen).
95
+ - Run `poe` from within the development environment to print a list of [Poe the Poet](https://github.com/nat-n/poethepoet) tasks available to run on this project.
96
+ - Run `uv add {package}` from within the development environment to install a runtime dependency and add it to `pyproject.toml` and `uv.lock`. Add `--dev` to install a development dependency.
97
+ - Run `uv sync --upgrade` from within the development environment to upgrade all dependencies to the latest versions allowed by `pyproject.toml`. Add `--only-dev` to upgrade the development dependencies only.
98
+ - Run `cz bump` to bump the package's version, update the `CHANGELOG.md`, and create a git tag. Then push the changes and the git tag with `git push origin main --tags`.
99
+ - Workflow: create a feature branch from `develop` -> create a merge request to `develop` -> on major releases, create a merge request from `develop` to `main`.
100
+
101
+ </details>
102
+
103
+ <details open>
104
+ <summary>Versioning</summary>
105
+
106
+ - Stable releases on `main` branch use [Semantic Versioning](https://semver.org/).
107
+ - Development releases on `develop` branch use the current version with a `.devN` suffix, where N is the number of published development releases so far for the current version.
108
+ - Docker images use either `develop` or `stable` tags for publishing from `develop` or `main` branches, respectively.
109
+ </details>
@@ -0,0 +1,93 @@
1
+ [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=data:image/svg%2bxml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAyNCAyNCI+PHBhdGggZmlsbD0iI2ZmZiIgZD0iTTE3IDE2VjdsLTYgNU0yIDlWOGwxLTFoMWw0IDMgOC04aDFsNCAyIDEgMXYxNGwtMSAxLTQgMmgtMWwtOC04LTQgM0gzbC0xLTF2LTFsMy0zIi8+PC9zdmc+)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://gitlab.kit.edu/ali.darijani/gmtorch)
2
+
3
+ # gmtorch
4
+
5
+ Probabilistic Generative AI Using Gaussian Mixture
6
+
7
+ ## Installing
8
+
9
+ To install this package, run:
10
+
11
+ ```sh
12
+ pip install gmtorch
13
+ ```
14
+
15
+ ## Using
16
+
17
+ Example usage:
18
+
19
+ ```python
20
+ import gmtorch
21
+
22
+ ...
23
+ ```
24
+
25
+ ## Contributing
26
+
27
+ <details>
28
+ <summary>Prerequisites</summary>
29
+
30
+ 1. [Generate an SSH key](https://docs.gitlab.com/ee/user/ssh.html#generate-an-ssh-key-pair) and [add the SSH key to your GitLab account](https://docs.gitlab.com/ee/user/ssh.html#add-an-ssh-key-to-your-gitlab-account).
31
+ 1. Configure SSH to automatically load your SSH keys:
32
+
33
+ ```sh
34
+ cat << EOF >> ~/.ssh/config
35
+
36
+ Host *
37
+ AddKeysToAgent yes
38
+ IgnoreUnknown UseKeychain
39
+ UseKeychain yes
40
+ ForwardAgent yes
41
+ EOF
42
+ ```
43
+
44
+ 1. [Install Docker Desktop](https://www.docker.com/get-started).
45
+ 1. [Install VS Code](https://code.visualstudio.com/) and [VS Code's Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). Alternatively, install [PyCharm](https://www.jetbrains.com/pycharm/download/).
46
+ 1. _Optional:_ install a [Nerd Font](https://www.nerdfonts.com/font-downloads) such as [FiraCode Nerd Font](https://github.com/ryanoasis/nerd-fonts/tree/master/patched-fonts/FiraCode) and [configure VS Code](https://github.com/tonsky/FiraCode/wiki/VS-Code-Instructions) or [PyCharm](https://github.com/tonsky/FiraCode/wiki/Intellij-products-instructions) to use it.
47
+
48
+ </details>
49
+
50
+ <details open>
51
+ <summary>Development environments</summary>
52
+
53
+ The following development environments are supported:
54
+
55
+
56
+ 1. ⭐️ _VS Code Dev Container (with container volume)_: click on [Open in Dev Containers](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://gitlab.kit.edu/ali.darijani/gmtorch) to clone this repository in a container volume and create a Dev Container with VS Code.
57
+ 1. ⭐️ _uv_: clone this repository and run the following from root of the repository:
58
+
59
+ ```sh
60
+ # Create and install a virtual environment
61
+ uv sync --python 3.12 --all-extras
62
+
63
+ # Activate the virtual environment
64
+ source .venv/bin/activate
65
+
66
+ # Install the pre-commit hooks
67
+ pre-commit install --install-hooks
68
+ ```
69
+
70
+ 1. _VS Code Dev Container_: clone this repository, open it with VS Code, and run <kbd>Ctrl/⌘</kbd> + <kbd>⇧</kbd> + <kbd>P</kbd> → _Dev Containers: Reopen in Container_.
71
+ 1. _PyCharm Dev Container_: clone this repository, open it with PyCharm, [create a Dev Container with Mount Sources](https://www.jetbrains.com/help/pycharm/start-dev-container-inside-ide.html), and [configure an existing Python interpreter](https://www.jetbrains.com/help/pycharm/configuring-python-interpreter.html#widget) at `/opt/venv/bin/python`.
72
+
73
+ </details>
74
+
75
+ <details open>
76
+ <summary>Developing</summary>
77
+
78
+ - This project follows the [Conventional Commits](https://www.conventionalcommits.org/) standard to automate [Semantic Versioning](https://semver.org/) and [Keep A Changelog](https://keepachangelog.com/) with [Commitizen](https://github.com/commitizen-tools/commitizen).
79
+ - Run `poe` from within the development environment to print a list of [Poe the Poet](https://github.com/nat-n/poethepoet) tasks available to run on this project.
80
+ - Run `uv add {package}` from within the development environment to install a run time dependency and add it to `pyproject.toml` and `uv.lock`. Add `--dev` to install a development dependency.
81
+ - Run `uv sync --upgrade` from within the development environment to upgrade all dependencies to the latest versions allowed by `pyproject.toml`. Add `--only-dev` to upgrade the development dependencies only.
82
+ - Run `cz bump` to bump the package's version, update the `CHANGELOG.md`, and create a git tag. Then push the changes and the git tag with `git push origin main --tags`.
83
+ - Workflow: create a feature branch from `develop` -> create a merge request to `develop` -> on major releases, create a merge request from `develop` to `main`.
84
+
85
+ </details>
86
+
87
+ <details open>
88
+ <summary>Versioning</summary>
89
+
90
+ - Stable releases on `main` branch use [Semantic Versioning](https://semver.org/).
91
+ - Development releases on `develop` branch use the current version with .devN suffix, where N is the number of published development releases so far for the current version.
92
+ - Docker images use either `develop` or `stable` tags for publishing from `develop` or `main` branches, respectively.
93
+ </details>
@@ -0,0 +1,19 @@
1
+ services:
2
+
3
+ devcontainer:
4
+ build:
5
+ target: dev
6
+ gpus: all
7
+ volumes:
8
+ - ..:/workspaces
9
+ - command-history-volume:/home/user/.history/
10
+ deploy:
11
+ resources:
12
+ reservations:
13
+ devices:
14
+ - driver: nvidia
15
+ count: all
16
+ capabilities: [ gpu ]
17
+
18
+ volumes:
19
+ command-history-volume:
@@ -0,0 +1,3 @@
1
+ # gmtorch
2
+
3
+ Probabilistic Generative AI Using Gaussian Mixture
@@ -0,0 +1,3 @@
1
+ # Reference
2
+
3
+ ::: gmtorch
@@ -0,0 +1,28 @@
1
+ {
2
+ "version": "15.2.2",
3
+ "vulnerabilities": [],
4
+ "scan": {
5
+ "analyzer": {
6
+ "id": "secrets",
7
+ "name": "secrets",
8
+ "url": "https://gitlab.com/gitlab-org/security-products/analyzers/secrets",
9
+ "vendor": {
10
+ "name": "GitLab"
11
+ },
12
+ "version": "7.18.0"
13
+ },
14
+ "scanner": {
15
+ "id": "gitleaks",
16
+ "name": "Gitleaks",
17
+ "url": "https://github.com/gitleaks/gitleaks",
18
+ "vendor": {
19
+ "name": "GitLab"
20
+ },
21
+ "version": "8.28.0"
22
+ },
23
+ "type": "secret_detection",
24
+ "start_time": "2025-10-28T15:59:43",
25
+ "end_time": "2025-10-28T15:59:44",
26
+ "status": "success"
27
+ }
28
+ }
@@ -0,0 +1,35 @@
1
+ site_name: gmtorch
2
+ site_description: Probabilistic Generative AI Using Gaussian Mixture
3
+ site_url: https://gitlab.kit.edu/ali.darijani/gmtorch
4
+ repo_url: https://gitlab.kit.edu/ali.darijani/gmtorch
5
+ repo_name: ali.darijani/gmtorch
6
+
7
+ strict: true
8
+
9
+ validation:
10
+ omitted_files: warn
11
+ absolute_links: warn
12
+ unrecognized_links: warn
13
+ anchors: warn
14
+
15
+ theme:
16
+ name: material
17
+
18
+ nav:
19
+ - Home: index.md
20
+ - Reference: reference.md
21
+
22
+ watch:
23
+ - src
24
+
25
+ plugins:
26
+ - search
27
+ - mkdocstrings:
28
+ handlers:
29
+ python:
30
+ options:
31
+ docstring_style: numpy
32
+
33
+ markdown_extensions:
34
+ - pymdownx.highlight
35
+ - pymdownx.superfences
@@ -0,0 +1,167 @@
1
+ [build-system] # https://docs.astral.sh/uv/concepts/projects/config/#build-systems
2
+ requires = ["hatchling>=1.27.0"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project] # https://packaging.python.org/en/latest/specifications/pyproject-toml/
6
+ name = "gmtorch"
7
+ version = "0.0.0.dev0"
8
+ description = "Probabilistic Generative AI Using Gaussian Mixture"
9
+ readme = "README.md"
10
+ authors = [
11
+ { name = "Ali Darijani", email = "ali.darijani@kit.edu" },
12
+ { name = "Tomass Marks Timmermans", email = "tomass.timmermans@student.kit.edu" }
13
+
14
+ ]
15
+ requires-python = ">=3.12,<4.0"
16
+ dependencies = [
17
+ "torch>=2.9.0",
18
+ "torchvision>=0.24.0",
19
+ ]
20
+
21
+ [tool.uv.sources]
22
+ torchvision = [
23
+ { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
24
+ ]
25
+
26
+ [[tool.uv.index]]
27
+ name = "pytorch-cu126"
28
+ url = "https://download.pytorch.org/whl/cu126"
29
+ explicit = true
30
+
31
+ [project.urls] # https://packaging.python.org/en/latest/specifications/well-known-project-urls/#well-known-labels
32
+ homepage = "https://gitlab.kit.edu/ali.darijani/gmtorch"
33
+ source = "https://gitlab.kit.edu/ali.darijani/gmtorch"
34
+ changelog = "https://gitlab.kit.edu/ali.darijani/gmtorch/-/blob/main/CHANGELOG.md"
35
+ releasenotes = "https://gitlab.kit.edu/ali.darijani/gmtorch/-/releases"
36
+ documentation = "https://gitlab.kit.edu/ali.darijani/gmtorch"
37
+ issues = "https://gitlab.kit.edu/ali.darijani/gmtorch/-/issues"
38
+
39
+ [dependency-groups] # https://docs.astral.sh/uv/concepts/projects/dependencies/#development-dependencies
40
+ dev = [
41
+ "commitizen (>=4.3.0)",
42
+ "coverage[toml] (>=7.6.10)",
43
+ "ipykernel (>=6.29.4)",
44
+ "ipython (>=8.18.0)",
45
+ "ipywidgets (>=8.1.2)",
46
+ "mypy (>=1.14.1)",
47
+ "mkdocs-material (>=9.5.21)",
48
+ "mkdocstrings[python] (>=0.26.2)",
49
+ "poethepoet (>=0.32.1)",
50
+ "pre-commit (>=4.0.1)",
51
+ "pytest (>=8.3.4)",
52
+ "pytest-mock (>=3.14.0)",
53
+ "pytest-xdist (>=3.6.1)",
54
+ "ruff (>=0.9.2)",
55
+ "typeguard (>=4.4.1)",
56
+ ]
57
+
58
+ [tool.commitizen] # https://commitizen-tools.github.io/commitizen/config/
59
+ bump_message = "bump: v$current_version → v$new_version"
60
+ tag_format = "v$version"
61
+ update_changelog_on_bump = true
62
+ version_provider = "uv"
63
+
64
+ [tool.coverage.report] # https://coverage.readthedocs.io/en/latest/config.html#report
65
+ fail_under = 50
66
+ precision = 1
67
+ show_missing = true
68
+ skip_covered = true
69
+
70
+ [tool.coverage.run] # https://coverage.readthedocs.io/en/latest/config.html#run
71
+ branch = true
72
+ command_line = "--module pytest"
73
+ data_file = "reports/.coverage"
74
+ source = ["src"]
75
+
76
+ [tool.coverage.xml] # https://coverage.readthedocs.io/en/latest/config.html#xml
77
+ output = "reports/coverage.xml"
78
+
79
+ [tool.mypy] # https://mypy.readthedocs.io/en/latest/config_file.html
80
+ junit_xml = "reports/mypy.xml"
81
+ strict = true
82
+ disallow_subclassing_any = false
83
+ disallow_untyped_decorators = false
84
+ ignore_missing_imports = true
85
+ pretty = true
86
+ show_column_numbers = true
87
+ show_error_codes = true
88
+ show_error_context = true
89
+ warn_unreachable = true
90
+
91
+ [tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref
92
+ addopts = "--color=yes --doctest-modules --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml"
93
+ filterwarnings = ["error", "ignore::DeprecationWarning"]
94
+ testpaths = ["src", "tests"]
95
+ xfail_strict = true
96
+
97
+ [tool.ruff] # https://docs.astral.sh/ruff/settings/
98
+ fix = true
99
+ line-length = 100
100
+ src = ["src", "tests"]
101
+ target-version = "py312"
102
+
103
+ [tool.ruff.format]
104
+ docstring-code-format = true
105
+ skip-magic-trailing-comma = true
106
+
107
+ [tool.ruff.lint]
108
+ select = ["ALL"]
109
+ ignore = ["CPY", "FIX", "ARG001", "COM812", "D203", "D213", "E501", "PD008", "PD009", "RET504", "S101", "TD003"]
110
+ unfixable = ["ERA001", "F401", "F841", "T201", "T203"]
111
+
112
+ [tool.ruff.lint.flake8-annotations]
113
+ allow-star-arg-any = true
114
+
115
+ [tool.ruff.lint.flake8-tidy-imports]
116
+ ban-relative-imports = "all"
117
+
118
+ [tool.ruff.lint.isort]
119
+ split-on-trailing-comma = false
120
+
121
+ [tool.ruff.lint.pycodestyle]
122
+ max-doc-length = 100
123
+
124
+ [tool.ruff.lint.pydocstyle]
125
+ convention = "numpy"
126
+
127
+ [tool.poe.executor] # https://github.com/nat-n/poethepoet
128
+ type = "simple"
129
+
130
+ [tool.poe.tasks]
131
+
132
+ [tool.poe.tasks.docs]
133
+ help = "Build or serve the documentation"
134
+ shell = """
135
+ if [ $serve ]
136
+ then {
137
+ mkdocs serve
138
+ } else {
139
+ mkdocs build
140
+ } fi
141
+ """
142
+
143
+ [[tool.poe.tasks.docs.args]]
144
+ help = "Serve the documentation locally with live reload"
145
+ type = "boolean"
146
+ name = "serve"
147
+ options = ["--serve"]
148
+
149
+ [tool.poe.tasks.lint]
150
+ help = "Lint this package"
151
+ cmd = """
152
+ pre-commit run
153
+ --all-files
154
+ --color always
155
+ """
156
+
157
+ [tool.poe.tasks.test]
158
+ help = "Test this package"
159
+
160
+ [[tool.poe.tasks.test.sequence]]
161
+ cmd = "coverage run"
162
+
163
+ [[tool.poe.tasks.test.sequence]]
164
+ cmd = "coverage report"
165
+
166
+ [[tool.poe.tasks.test.sequence]]
167
+ cmd = "coverage xml"
@@ -0,0 +1,2 @@
1
+ torch>=2.0.0
2
+ pytest>=7.0.0
File without changes
@@ -0,0 +1 @@
1
+ """gmtorch."""
@@ -0,0 +1 @@
1
+ """gmtorch test suite."""
@@ -0,0 +1 @@
1
+ """CUDA tests for gmtorch."""
@@ -0,0 +1,23 @@
1
+ """Unit tests for CUDA functionality."""
2
+
3
+ import logging
4
+
5
+ import torch.cuda
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
def test_cuda_availability() -> None:
    """Verify that PyTorch can see a CUDA device, then log the GPU inventory.

    Fails (by design) when run on a host without a usable CUDA runtime, so
    this test doubles as a smoke check of the GPU-enabled dev container.
    """
    logger.info("Testing CUDA...")

    # GIVEN / WHEN
    is_available = torch.cuda.is_available()

    # THEN
    assert is_available is True

    # Report runtime details for debugging container/driver mismatches.
    logger.info("CUDA available!")
    logger.info("✓ CUDA version: %s", torch.version.cuda)
    num_gpus = torch.cuda.device_count()
    logger.info("✓ Number of GPUs: %d", num_gpus)
    for gpu_index in range(num_gpus):
        logger.info(" - GPU %d: %s", gpu_index, torch.cuda.get_device_name(gpu_index))
@@ -0,0 +1,8 @@
1
+ """Test gmtorch."""
2
+
3
+ import gmtorch
4
+
5
+
6
def test_import() -> None:
    """Smoke-test that the gmtorch package imports and exposes a module name."""
    package_name = gmtorch.__name__
    assert isinstance(package_name, str)