jax-image-models 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_image_models-0.3.3/.gitattributes +2 -0
- jax_image_models-0.3.3/.github/workflows/ci.yml +32 -0
- jax_image_models-0.3.3/.github/workflows/claude-code-review.yml +57 -0
- jax_image_models-0.3.3/.github/workflows/claude.yml +50 -0
- jax_image_models-0.3.3/.github/workflows/release.yaml +26 -0
- jax_image_models-0.3.3/.gitignore +205 -0
- jax_image_models-0.3.3/.pre-commit-config.yaml +27 -0
- jax_image_models-0.3.3/.python-version +1 -0
- jax_image_models-0.3.3/.vscode/settings.json +3 -0
- jax_image_models-0.3.3/LICENSE +21 -0
- jax_image_models-0.3.3/PKG-INFO +36 -0
- jax_image_models-0.3.3/README.md +23 -0
- jax_image_models-0.3.3/docs/index.md +74 -0
- jax_image_models-0.3.3/docs/models/CLIP.md +82 -0
- jax_image_models-0.3.3/docs/models/SigLIP.md +80 -0
- jax_image_models-0.3.3/docs/models/ViT.md +61 -0
- jax_image_models-0.3.3/examples/bench_fsdp.py +130 -0
- jax_image_models-0.3.3/examples/benchmark_clip_loss.py +102 -0
- jax_image_models-0.3.3/examples/clip_inference.py +85 -0
- jax_image_models-0.3.3/examples/clip_training_nnxjit.py +273 -0
- jax_image_models-0.3.3/examples/model_definitions.ipynb +168 -0
- jax_image_models-0.3.3/examples/saving_example.py +32 -0
- jax_image_models-0.3.3/examples/siglip_inference.py +67 -0
- jax_image_models-0.3.3/examples/vit_inference.py +74 -0
- jax_image_models-0.3.3/examples/vit_training.py +243 -0
- jax_image_models-0.3.3/images/test_image.jpg +0 -0
- jax_image_models-0.3.3/mkdocs.yml +28 -0
- jax_image_models-0.3.3/pyproject.toml +73 -0
- jax_image_models-0.3.3/ruff.toml +77 -0
- jax_image_models-0.3.3/src/jimm/__init__.py +33 -0
- jax_image_models-0.3.3/src/jimm/common/autotuning.py +150 -0
- jax_image_models-0.3.3/src/jimm/common/loading_utils.py +248 -0
- jax_image_models-0.3.3/src/jimm/common/sharding.py +98 -0
- jax_image_models-0.3.3/src/jimm/common/tokamax_attention.py +108 -0
- jax_image_models-0.3.3/src/jimm/common/transformer.py +299 -0
- jax_image_models-0.3.3/src/jimm/common/utils.py +157 -0
- jax_image_models-0.3.3/src/jimm/common/vit.py +340 -0
- jax_image_models-0.3.3/src/jimm/models/__init__.py +13 -0
- jax_image_models-0.3.3/src/jimm/models/clip/__init__.py +3 -0
- jax_image_models-0.3.3/src/jimm/models/clip/clip_model.py +635 -0
- jax_image_models-0.3.3/src/jimm/models/clip/params.py +520 -0
- jax_image_models-0.3.3/src/jimm/models/clip/sharding.py +57 -0
- jax_image_models-0.3.3/src/jimm/models/siglip/__init__.py +3 -0
- jax_image_models-0.3.3/src/jimm/models/siglip/params.py +608 -0
- jax_image_models-0.3.3/src/jimm/models/siglip/sharding.py +56 -0
- jax_image_models-0.3.3/src/jimm/models/siglip/siglip_model.py +614 -0
- jax_image_models-0.3.3/src/jimm/models/vit/__init__.py +3 -0
- jax_image_models-0.3.3/src/jimm/models/vit/params.py +272 -0
- jax_image_models-0.3.3/src/jimm/models/vit/sharding.py +55 -0
- jax_image_models-0.3.3/src/jimm/models/vit/vit_model.py +209 -0
- jax_image_models-0.3.3/tests/benchmark_utils.py +73 -0
- jax_image_models-0.3.3/tests/conftest.py +11 -0
- jax_image_models-0.3.3/tests/test_checkpointing.py +135 -0
- jax_image_models-0.3.3/tests/test_clip.py +387 -0
- jax_image_models-0.3.3/tests/test_loading_utils.py +92 -0
- jax_image_models-0.3.3/tests/test_siglip.py +335 -0
- jax_image_models-0.3.3/tests/test_transformer.py +60 -0
- jax_image_models-0.3.3/tests/test_vit.py +270 -0
- jax_image_models-0.3.3/uv.lock +4760 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
name: ci

# Builds the mkdocs site and publishes it to the gh-pages branch on every push
# to master or main.
on:
  push:
    branches:
      - master
      - main

# `mkdocs gh-deploy` pushes to the gh-pages branch, so the token needs write access.
permissions:
  contents: write

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # Commit identity used by `mkdocs gh-deploy` when it pushes gh-pages.
      # Values are quoted so bash does not treat `[bot]` as a glob character class.
      - name: Configure Git Credentials
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"

      - uses: actions/setup-python@v5
        with:
          python-version: "3.x"

      # ISO week number (%V): rotates the mkdocs-material cache key weekly.
      # "$GITHUB_ENV" is quoted to avoid word-splitting of the file path.
      - run: echo "cache_id=$(date --utc '+%V')" >> "$GITHUB_ENV"

      - uses: actions/cache@v4
        with:
          key: mkdocs-material-${{ env.cache_id }}
          path: .cache
          restore-keys: |
            mkdocs-material-

      - name: Install uv
        uses: astral-sh/setup-uv@v5

      - name: Install the project
        run: uv sync --locked --all-extras --dev && uv pip install -e .

      - run: uv run mkdocs gh-deploy --force
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
name: Claude Code Review

# Automated Claude review on every newly opened or updated pull request.
on:
  pull_request:
    types:
      - opened
      - synchronize
    # Optional: restrict the trigger to changes in specific files, e.g.
    # paths:
    #   - "src/**/*.py"
    #   - "tests/**/*.py"

jobs:
  claude-review:
    # Optional: only review PRs from certain authors, e.g.
    # if: |
    #   github.event.pull_request.user.login == 'external-contributor' ||
    #   github.event.pull_request.user.login == 'new-developer' ||
    #   github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'

    runs-on: ubuntu-latest

    # GITHUB_TOKEN scopes for this job; id-token is requested so the action
    # can perform an OIDC exchange.
    permissions:
      contents: read
      pull-requests: read
      issues: read
      id-token: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run Claude Code Review
        id: claude-review
        uses: anthropics/claude-code-action@v1
        with:
          claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
          prompt: |
            REPO: ${{ github.repository }}
            PR NUMBER: ${{ github.event.pull_request.number }}

            Please review this pull request and provide feedback on:
            - Code quality and best practices
            - Potential bugs or issues
            - Performance considerations
            - Security concerns
            - Test coverage

            Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.

            Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR.

          # Tools Claude may invoke: read-only `gh` queries plus `gh pr comment`.
          # Option reference: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # and https://docs.claude.com/en/docs/claude-code/cli-reference
          claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'
|
|
57
|
+
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
name: Claude Code

# Interactive Claude assistant, invoked by mentioning @claude in issues,
# issue/PR comments, PR review comments, or PR reviews.
on:
  issue_comment:
    types:
      - created
  pull_request_review_comment:
    types:
      - created
  issues:
    types:
      - opened
      - assigned
  pull_request_review:
    types:
      - submitted

jobs:
  claude:
    # Run only when the triggering text actually contains "@claude".
    if: |
      (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
      (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
      (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
      (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))

    runs-on: ubuntu-latest

    permissions:
      contents: read
      pull-requests: read
      issues: read
      id-token: write
      actions: read  # lets Claude read CI results on PRs

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@v1
        with:
          claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}

          # Optional: grants Claude read access to CI results on PRs.
          additional_permissions: |
            actions: read

          # Optional: fixed instructions. When omitted, Claude follows the
          # instructions in the comment that mentioned it.
          # prompt: 'Update the pull request description to include a summary of changes.'

          # Optional: extra CLI arguments. Option reference:
          # https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # https://docs.claude.com/en/docs/claude-code/cli-reference
          # claude_args: '--allowed-tools Bash(gh pr:*)'
|
|
50
|
+
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
name: "Publish release to PyPI"

# Builds and uploads the package whenever a tag starting with "v" is pushed.
on:
  push:
    tags:
      - "v*"

jobs:
  run:
    runs-on: ubuntu-latest
    environment:
      name: pypi
    permissions:
      # OIDC token for PyPI trusted publishing; no API token is passed to `uv publish`.
      id-token: write
      contents: read
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          # This job never pushes, so don't leave the GITHUB_TOKEN in .git/config
          # where later steps (build backend, publish) could read it.
          persist-credentials: false
      - name: Install uv
        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
      - name: Install Python
        run: uv python install 3.13
      - name: Build
        run: uv build
      - name: Publish
        run: uv publish
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
bin
|
|
2
|
+
tmp
|
|
3
|
+
weights
|
|
4
|
+
.claude
|
|
5
|
+
tests/tokamax_cache
|
|
6
|
+
# pixi environments
|
|
7
|
+
.pixi
|
|
8
|
+
*.egg-info
|
|
9
|
+
.aider*
|
|
10
|
+
|
|
11
|
+
# Byte-compiled / optimized / DLL files
|
|
12
|
+
__pycache__/
|
|
13
|
+
*.py[cod]
|
|
14
|
+
*$py.class
|
|
15
|
+
|
|
16
|
+
# C extensions
|
|
17
|
+
*.so
|
|
18
|
+
|
|
19
|
+
# Distribution / packaging
|
|
20
|
+
.Python
|
|
21
|
+
build/
|
|
22
|
+
develop-eggs/
|
|
23
|
+
dist/
|
|
24
|
+
downloads/
|
|
25
|
+
eggs/
|
|
26
|
+
.eggs/
|
|
27
|
+
lib/
|
|
28
|
+
lib64/
|
|
29
|
+
parts/
|
|
30
|
+
sdist/
|
|
31
|
+
var/
|
|
32
|
+
wheels/
|
|
33
|
+
share/python-wheels/
|
|
34
|
+
*.egg-info/
|
|
35
|
+
.installed.cfg
|
|
36
|
+
*.egg
|
|
37
|
+
MANIFEST
|
|
38
|
+
|
|
39
|
+
# PyInstaller
|
|
40
|
+
# Usually these files are written by a python script from a template
|
|
41
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
42
|
+
*.manifest
|
|
43
|
+
*.spec
|
|
44
|
+
|
|
45
|
+
# Installer logs
|
|
46
|
+
pip-log.txt
|
|
47
|
+
pip-delete-this-directory.txt
|
|
48
|
+
|
|
49
|
+
# Unit test / coverage reports
|
|
50
|
+
htmlcov/
|
|
51
|
+
.tox/
|
|
52
|
+
.nox/
|
|
53
|
+
.coverage
|
|
54
|
+
.coverage.*
|
|
55
|
+
.cache
|
|
56
|
+
nosetests.xml
|
|
57
|
+
coverage.xml
|
|
58
|
+
*.cover
|
|
59
|
+
*.py,cover
|
|
60
|
+
.hypothesis/
|
|
61
|
+
.pytest_cache/
|
|
62
|
+
cover/
|
|
63
|
+
|
|
64
|
+
# Translations
|
|
65
|
+
*.mo
|
|
66
|
+
*.pot
|
|
67
|
+
|
|
68
|
+
# Django stuff:
|
|
69
|
+
*.log
|
|
70
|
+
local_settings.py
|
|
71
|
+
db.sqlite3
|
|
72
|
+
db.sqlite3-journal
|
|
73
|
+
|
|
74
|
+
# Flask stuff:
|
|
75
|
+
instance/
|
|
76
|
+
.webassets-cache
|
|
77
|
+
|
|
78
|
+
# Scrapy stuff:
|
|
79
|
+
.scrapy
|
|
80
|
+
|
|
81
|
+
# Sphinx documentation
|
|
82
|
+
docs/_build/
|
|
83
|
+
|
|
84
|
+
# PyBuilder
|
|
85
|
+
.pybuilder/
|
|
86
|
+
target/
|
|
87
|
+
|
|
88
|
+
# Jupyter Notebook
|
|
89
|
+
.ipynb_checkpoints
|
|
90
|
+
|
|
91
|
+
# IPython
|
|
92
|
+
profile_default/
|
|
93
|
+
ipython_config.py
|
|
94
|
+
|
|
95
|
+
# pyenv
|
|
96
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
97
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
98
|
+
# .python-version
|
|
99
|
+
|
|
100
|
+
# pipenv
|
|
101
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
102
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
103
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
104
|
+
# install all needed dependencies.
|
|
105
|
+
#Pipfile.lock
|
|
106
|
+
|
|
107
|
+
# UV
|
|
108
|
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
|
109
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
110
|
+
# commonly ignored for libraries.
|
|
111
|
+
#uv.lock
|
|
112
|
+
|
|
113
|
+
# poetry
|
|
114
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
115
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
116
|
+
# commonly ignored for libraries.
|
|
117
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
118
|
+
#poetry.lock
|
|
119
|
+
#poetry.toml
|
|
120
|
+
|
|
121
|
+
# pdm
|
|
122
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
123
|
+
#pdm.lock
|
|
124
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
125
|
+
# in version control.
|
|
126
|
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
|
127
|
+
.pdm.toml
|
|
128
|
+
.pdm-python
|
|
129
|
+
.pdm-build/
|
|
130
|
+
|
|
131
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
132
|
+
__pypackages__/
|
|
133
|
+
|
|
134
|
+
# Celery stuff
|
|
135
|
+
celerybeat-schedule
|
|
136
|
+
celerybeat.pid
|
|
137
|
+
|
|
138
|
+
# SageMath parsed files
|
|
139
|
+
*.sage.py
|
|
140
|
+
|
|
141
|
+
# Environments
|
|
142
|
+
.env
|
|
143
|
+
.venv
|
|
144
|
+
env/
|
|
145
|
+
venv/
|
|
146
|
+
ENV/
|
|
147
|
+
env.bak/
|
|
148
|
+
venv.bak/
|
|
149
|
+
|
|
150
|
+
# Spyder project settings
|
|
151
|
+
.spyderproject
|
|
152
|
+
.spyproject
|
|
153
|
+
|
|
154
|
+
# Rope project settings
|
|
155
|
+
.ropeproject
|
|
156
|
+
|
|
157
|
+
# mkdocs documentation
|
|
158
|
+
/site
|
|
159
|
+
|
|
160
|
+
# mypy
|
|
161
|
+
.mypy_cache/
|
|
162
|
+
.dmypy.json
|
|
163
|
+
dmypy.json
|
|
164
|
+
|
|
165
|
+
# Pyre type checker
|
|
166
|
+
.pyre/
|
|
167
|
+
|
|
168
|
+
# pytype static type analyzer
|
|
169
|
+
.pytype/
|
|
170
|
+
|
|
171
|
+
# Cython debug symbols
|
|
172
|
+
cython_debug/
|
|
173
|
+
|
|
174
|
+
# PyCharm
|
|
175
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
176
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
177
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
178
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
179
|
+
#.idea/
|
|
180
|
+
|
|
181
|
+
# Abstra
|
|
182
|
+
# Abstra is an AI-powered process automation framework.
|
|
183
|
+
# Ignore directories containing user credentials, local state, and settings.
|
|
184
|
+
# Learn more at https://abstra.io/docs
|
|
185
|
+
.abstra/
|
|
186
|
+
|
|
187
|
+
# Visual Studio Code
|
|
188
|
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
|
189
|
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
|
190
|
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
|
191
|
+
# you could uncomment the following to ignore the entire vscode folder
|
|
192
|
+
# .vscode/
|
|
193
|
+
|
|
194
|
+
# Ruff stuff:
|
|
195
|
+
.ruff_cache/
|
|
196
|
+
|
|
197
|
+
# PyPI configuration file
|
|
198
|
+
.pypirc
|
|
199
|
+
|
|
200
|
+
# Cursor
|
|
201
|
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
|
202
|
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
|
203
|
+
# refer to https://docs.cursor.com/context/ignore-files
|
|
204
|
+
.cursorignore
|
|
205
|
+
.cursorindexingignore
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
repos:
  # Standard hooks
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: check-case-conflict
      - id: check-docstring-first
      - id: check-merge-conflict
      - id: check-symlinks
      - id: check-toml
      - id: debug-statements
      - id: mixed-line-ending
      - id: requirements-txt-fixer
      - id: trailing-whitespace

  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.14.1
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      # Sort imports. The `ruff` hook's entry is already `ruff check --force-exclude`,
      # so args must not repeat "check" (ruff would treat it as a file path and fail).
      - id: ruff
        name: ruff (sort imports)
        args: [--select, I, --fix]
      # Run the formatter.
      - id: ruff-format
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.11
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Pinak Paliwal
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jax-image-models
|
|
3
|
+
Version: 0.3.3
|
|
4
|
+
Summary: Jax Image Modeling of Models
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
|
+
Requires-Dist: flax>=0.10.6
|
|
8
|
+
Requires-Dist: jax>=0.6.2
|
|
9
|
+
Requires-Dist: jaxtyping>=0.3.2
|
|
10
|
+
Requires-Dist: safetensors>=0.5.3
|
|
11
|
+
Requires-Dist: tokamax>=0.0.12
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
|
|
14
|
+
# Jax Image Modeling of Models (jimm)
|
|
15
|
+
Docs are at: [https://pythoncrazy.github.io/jimm](https://pythoncrazy.github.io/jimm)
|
|
16
|
+
- This aims to be the JAX counterpart to timm, except that for image-text models (CLIP, SigLIP, etc.) the text model is fully supported as well.
|
|
17
|
+
- Built with Flax NNX; supports loading weights from `pytorch_model.bin` and safetensors files, either locally or from the HuggingFace Hub.
|
|
18
|
+
|
|
19
|
+
Models Supported:
|
|
20
|
+
- Vision Transformers
|
|
21
|
+
- With or without a linear classification head
|
|
22
|
+
- Using a CLS Token for pooling, or using Multihead Attention Pooling
|
|
23
|
+
- Can load any standard variant of Vision Transformers of any size/resolution (e.g. "google/vit-base-patch16-224" or "google/vit-large-patch16-384")
|
|
24
|
+
- CLIP
|
|
25
|
+
- Can load any CLIP checkpoint from HuggingFace (such as "openai/clip-vit-base-patch32" or "geolocal/StreetCLIP")
|
|
26
|
+
- SigLIP
|
|
27
|
+
- Can load any non-NaFlex SigLIP checkpoint, from both SigLIP v1 and SigLIP v2 (e.g. "google/siglip-base-patch16-256" or "google/siglip2-large-patch16-512"), from HuggingFace or locally
|
|
28
|
+
## Installation
|
|
29
|
+
### Using pixi.sh:
|
|
30
|
+
`pixi add --pypi "jax-image-models @ git+https://github.com/pythoncrazy/jimm.git"`
|
|
31
|
+
### Using uv
|
|
32
|
+
`uv add git+https://github.com/pythoncrazy/jimm.git`
|
|
33
|
+
or if you prefer to not add as a direct dependency:
|
|
34
|
+
`uv pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
35
|
+
### Using pip/conda
|
|
36
|
+
`pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Jax Image Modeling of Models (jimm)
|
|
2
|
+
Docs are at: [https://pythoncrazy.github.io/jimm](https://pythoncrazy.github.io/jimm)
|
|
3
|
+
- This aims to be the JAX counterpart to timm, except that for image-text models (CLIP, SigLIP, etc.) the text model is fully supported as well.
|
|
4
|
+
- Built with Flax NNX; supports loading weights from `pytorch_model.bin` and safetensors files, either locally or from the HuggingFace Hub.
|
|
5
|
+
|
|
6
|
+
Models Supported:
|
|
7
|
+
- Vision Transformers
|
|
8
|
+
- With or without a linear classification head
|
|
9
|
+
- Using a CLS Token for pooling, or using Multihead Attention Pooling
|
|
10
|
+
- Can load any standard variant of Vision Transformers of any size/resolution (e.g. "google/vit-base-patch16-224" or "google/vit-large-patch16-384")
|
|
11
|
+
- CLIP
|
|
12
|
+
- Can load any CLIP checkpoint from HuggingFace (such as "openai/clip-vit-base-patch32" or "geolocal/StreetCLIP")
|
|
13
|
+
- SigLIP
|
|
14
|
+
- Can load any non-NaFlex SigLIP checkpoint, from both SigLIP v1 and SigLIP v2 (e.g. "google/siglip-base-patch16-256" or "google/siglip2-large-patch16-512"), from HuggingFace or locally
|
|
15
|
+
## Installation
|
|
16
|
+
### Using pixi.sh:
|
|
17
|
+
`pixi add --pypi "jax-image-models @ git+https://github.com/pythoncrazy/jimm.git"`
|
|
18
|
+
### Using uv
|
|
19
|
+
`uv add git+https://github.com/pythoncrazy/jimm.git`
|
|
20
|
+
or if you prefer to not add as a direct dependency:
|
|
21
|
+
`uv pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
22
|
+
### Using pip/conda
|
|
23
|
+
`pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# jimm docs
|
|
2
|
+
|
|
3
|
+
jimm is a JAX image-model library built on Flax NNX. It supports loading pretrained weights from HuggingFace, hardware-accelerated attention, and FSDP-style explicit sharding.
|
|
4
|
+
|
|
5
|
+
## Models Implemented
|
|
6
|
+
|
|
7
|
+
* Vision Transformers (CLS pooling or Multihead Attention Pooling)
|
|
8
|
+
* CLIP
|
|
9
|
+
* SigLIP
|
|
10
|
+
* more tbd — contribute, it's open source!
|
|
11
|
+
|
|
12
|
+
## Flash / Splash Attention (via Tokamax)
|
|
13
|
+
|
|
14
|
+
All models accept an `attention_fn` argument for hardware-accelerated attention using [Tokamax](https://github.com/openxla/tokamax):
|
|
15
|
+
|
|
16
|
+
| Backend | Hardware | Notes |
|
|
17
|
+
|---------|----------|-------|
|
|
18
|
+
| `"mosaic"` | NVIDIA H100 (SM90) / B100 (SM100) | Pallas Mosaic GPU kernel |
|
|
19
|
+
| `"triton"` | Any NVIDIA GPU | Pallas Triton kernel |
|
|
20
|
+
| `"cudnn"` | NVIDIA GPU | Via JAX-NN / cuDNN |
|
|
21
|
+
| `"mosaic_tpu"` | TPU v5 / v7 | Splash attention (block-sparse) |
|
|
22
|
+
| `"xla_chunked"` | GPU / TPU | Flash-style chunked XLA |
|
|
23
|
+
| `"xla"` | Any | Standard XLA fallback |
|
|
24
|
+
|
|
25
|
+
Pass a list for automatic fallback:
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
import jimm
|
|
29
|
+
|
|
30
|
+
# GPU: try H100 Mosaic kernel, fall back to Triton, then XLA
|
|
31
|
+
model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
|
|
32
|
+
attention_fn=jimm.make_tokamax_attention(["mosaic", "triton", "xla"]))
|
|
33
|
+
|
|
34
|
+
# TPU: try Splash attention, fall back to chunked XLA
|
|
35
|
+
model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
|
|
36
|
+
attention_fn=jimm.make_tokamax_attention(["mosaic_tpu", "xla_chunked"]))
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
> **Note:** Flash/Splash attention does not provide a speedup at typical vision/text context lengths (e.g. 256 image tokens, 77 text tokens). The primary benefit is memory reduction at longer context lengths.
|
|
40
|
+
|
|
41
|
+
## FSDP / Explicit Sharding
|
|
42
|
+
|
|
43
|
+
All models support JAX explicit sharding (FSDP-style) out of the box. Set up a mesh before model creation:
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
from jax.experimental import mesh_utils
|
|
47
|
+
from jax.sharding import AxisType, Mesh
|
|
48
|
+
import jax
|
|
49
|
+
|
|
50
|
+
n_devices = jax.device_count()
|
|
51
|
+
mesh = Mesh(
|
|
52
|
+
mesh_utils.create_device_mesh((1, n_devices)),
|
|
53
|
+
("data", "fsdp"),
|
|
54
|
+
axis_types=(AxisType.Explicit, AxisType.Explicit),
|
|
55
|
+
)
|
|
56
|
+
jax.set_mesh(mesh)
|
|
57
|
+
|
|
58
|
+
model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14")
|
|
59
|
+
# params are automatically sharded across fsdp axis
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
Each model ships with a default sharding config (`CLIPSharding`, `SigLIPSharding`, `ViTSharding`) that shards large weight matrices on the contracting (`in_features`) dimension, keeping activations batch-sharded only. Specs are per-layer shapes; the Transformer stack patches Variable metadata after `nnx.vmap` so the optimizer (e.g. AdamW via `nnx.Optimizer`) initialises its state with the correct stacked spec — no manual fixups needed.
|
|
63
|
+
|
|
64
|
+
Pass `sharding=jimm.common.sharding.NoSharding()` to disable all sharding.
|
|
65
|
+
|
|
66
|
+
## Installation
|
|
67
|
+
### Using pixi.sh:
|
|
68
|
+
`pixi add --pypi "jax-image-models @ git+https://github.com/pythoncrazy/jimm.git"`
|
|
69
|
+
### Using uv
|
|
70
|
+
`uv add git+https://github.com/pythoncrazy/jimm.git`
|
|
71
|
+
or if you prefer to not add as a direct dependency:
|
|
72
|
+
`uv pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
73
|
+
### Using pip/conda
|
|
74
|
+
`pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# CLIP (Contrastive Language–Image Pre-training)
|
|
2
|
+
|
|
3
|
+
CLIP (Contrastive Language–Image Pre-training) is a neural network architecture that learns visual concepts from natural language supervision. It is trained on a large dataset of image-text pairs to create a unified vision-language model that can understand both images and text in a shared semantic space.
|
|
4
|
+
|
|
5
|
+
CLIP consists of two main components:
|
|
6
|
+
1. A vision encoder (Vision Transformer) that processes images into visual features
|
|
7
|
+
2. A text encoder (Transformer) that processes text into textual features
|
|
8
|
+
|
|
9
|
+
The model is trained using contrastive learning, where it learns to maximize the cosine similarity between the embeddings of matching image-text pairs while minimizing it for non-matching pairs. This allows CLIP to perform zero-shot classification by comparing image embeddings with text embeddings of potential labels.
|
|
10
|
+
|
|
11
|
+
CLIP was introduced in the paper ["Learning Transferable Visual Models From Natural Language Supervision"](https://arxiv.org/abs/2103.00020) and has shown remarkable zero-shot generalization across a wide range of visual classification tasks.
|
|
12
|
+
|
|
13
|
+
## Flash / Splash Attention
|
|
14
|
+
|
|
15
|
+
CLIP supports hardware-accelerated attention via [Tokamax](https://github.com/openxla/tokamax). Pass an `attention_fn` at construction time:
|
|
16
|
+
|
|
17
|
+
| Backend | Hardware | Notes |
|
|
18
|
+
|---------|----------|-------|
|
|
19
|
+
| `"mosaic"` | NVIDIA H100 (SM90) / B100 (SM100) | Pallas Mosaic GPU kernel |
|
|
20
|
+
| `"triton"` | Any NVIDIA GPU | Pallas Triton kernel |
|
|
21
|
+
| `"cudnn"` | NVIDIA GPU | Via JAX-NN / cuDNN |
|
|
22
|
+
| `"mosaic_tpu"` | TPU v5 / v7 | Splash attention (block-sparse) |
|
|
23
|
+
| `"xla_chunked"` | GPU / TPU | Flash-style chunked XLA |
|
|
24
|
+
| `"xla"` | Any | Standard XLA fallback |
|
|
25
|
+
|
|
26
|
+
```python
|
|
27
|
+
import jimm
|
|
28
|
+
|
|
29
|
+
# GPU: try H100 Mosaic kernel, fall back to Triton, then XLA
|
|
30
|
+
model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
|
|
31
|
+
attention_fn=jimm.make_tokamax_attention(["mosaic", "triton", "xla"]))
|
|
32
|
+
|
|
33
|
+
# TPU: try Splash attention, fall back to chunked XLA
|
|
34
|
+
model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
|
|
35
|
+
attention_fn=jimm.make_tokamax_attention(["mosaic_tpu", "xla_chunked"]))
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
You can also apply different kernels to each encoder via `vision_attention_fn` and `text_attention_fn`.
|
|
39
|
+
|
|
40
|
+
> **Note:** Flash/Splash attention does not provide a speedup at typical CLIP context lengths (256 image tokens, 77 text tokens). The primary benefit is memory reduction at longer sequence lengths.
|
|
41
|
+
|
|
42
|
+
## FSDP / Explicit Sharding
|
|
43
|
+
|
|
44
|
+
CLIP supports JAX explicit sharding (FSDP-style) out of the box via `CLIPSharding`. Large weight matrices are sharded on the contracting (`in_features`) dimension so that activations carry only the batch-axis sharding, avoiding duplicate-axis conflicts.
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from jax.experimental import mesh_utils
|
|
48
|
+
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P
|
|
49
|
+
import jax
|
|
50
|
+
|
|
51
|
+
n_devices = jax.device_count()
|
|
52
|
+
mesh = Mesh(
|
|
53
|
+
mesh_utils.create_device_mesh((1, n_devices)),
|
|
54
|
+
("data", "fsdp"),
|
|
55
|
+
axis_types=(AxisType.Explicit, AxisType.Explicit),
|
|
56
|
+
)
|
|
57
|
+
jax.set_mesh(mesh)
|
|
58
|
+
|
|
59
|
+
model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14")
|
|
60
|
+
# model params are automatically sharded across fsdp axis
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
`CLIPSharding` specs represent **per-layer** shapes. The `Transformer` stack prepends `None` for the scan axis to the Variable metadata after `nnx.vmap`, so the optimizer (e.g. `nnx.Optimizer` with AdamW) receives the correct stacked spec and initialises its state without any manual fixups.
|
|
64
|
+
|
|
65
|
+
To disable sharding, pass `sharding=jimm.common.sharding.NoSharding()`.
|
|
66
|
+
|
|
67
|
+
::: jimm.models.clip.CLIPVisionModel
|
|
68
|
+
options:
|
|
69
|
+
show_root_heading: true
|
|
70
|
+
show_source: true
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
::: jimm.models.clip.CLIPTextModel
|
|
74
|
+
options:
|
|
75
|
+
show_root_heading: true
|
|
76
|
+
show_source: true
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
::: jimm.models.clip.CLIP
|
|
80
|
+
options:
|
|
81
|
+
show_root_heading: true
|
|
82
|
+
show_source: true
|