jax-image-models 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. jax_image_models-0.3.3/.gitattributes +2 -0
  2. jax_image_models-0.3.3/.github/workflows/ci.yml +32 -0
  3. jax_image_models-0.3.3/.github/workflows/claude-code-review.yml +57 -0
  4. jax_image_models-0.3.3/.github/workflows/claude.yml +50 -0
  5. jax_image_models-0.3.3/.github/workflows/release.yaml +26 -0
  6. jax_image_models-0.3.3/.gitignore +205 -0
  7. jax_image_models-0.3.3/.pre-commit-config.yaml +27 -0
  8. jax_image_models-0.3.3/.python-version +1 -0
  9. jax_image_models-0.3.3/.vscode/settings.json +3 -0
  10. jax_image_models-0.3.3/LICENSE +21 -0
  11. jax_image_models-0.3.3/PKG-INFO +36 -0
  12. jax_image_models-0.3.3/README.md +23 -0
  13. jax_image_models-0.3.3/docs/index.md +74 -0
  14. jax_image_models-0.3.3/docs/models/CLIP.md +82 -0
  15. jax_image_models-0.3.3/docs/models/SigLIP.md +80 -0
  16. jax_image_models-0.3.3/docs/models/ViT.md +61 -0
  17. jax_image_models-0.3.3/examples/bench_fsdp.py +130 -0
  18. jax_image_models-0.3.3/examples/benchmark_clip_loss.py +102 -0
  19. jax_image_models-0.3.3/examples/clip_inference.py +85 -0
  20. jax_image_models-0.3.3/examples/clip_training_nnxjit.py +273 -0
  21. jax_image_models-0.3.3/examples/model_definitions.ipynb +168 -0
  22. jax_image_models-0.3.3/examples/saving_example.py +32 -0
  23. jax_image_models-0.3.3/examples/siglip_inference.py +67 -0
  24. jax_image_models-0.3.3/examples/vit_inference.py +74 -0
  25. jax_image_models-0.3.3/examples/vit_training.py +243 -0
  26. jax_image_models-0.3.3/images/test_image.jpg +0 -0
  27. jax_image_models-0.3.3/mkdocs.yml +28 -0
  28. jax_image_models-0.3.3/pyproject.toml +73 -0
  29. jax_image_models-0.3.3/ruff.toml +77 -0
  30. jax_image_models-0.3.3/src/jimm/__init__.py +33 -0
  31. jax_image_models-0.3.3/src/jimm/common/autotuning.py +150 -0
  32. jax_image_models-0.3.3/src/jimm/common/loading_utils.py +248 -0
  33. jax_image_models-0.3.3/src/jimm/common/sharding.py +98 -0
  34. jax_image_models-0.3.3/src/jimm/common/tokamax_attention.py +108 -0
  35. jax_image_models-0.3.3/src/jimm/common/transformer.py +299 -0
  36. jax_image_models-0.3.3/src/jimm/common/utils.py +157 -0
  37. jax_image_models-0.3.3/src/jimm/common/vit.py +340 -0
  38. jax_image_models-0.3.3/src/jimm/models/__init__.py +13 -0
  39. jax_image_models-0.3.3/src/jimm/models/clip/__init__.py +3 -0
  40. jax_image_models-0.3.3/src/jimm/models/clip/clip_model.py +635 -0
  41. jax_image_models-0.3.3/src/jimm/models/clip/params.py +520 -0
  42. jax_image_models-0.3.3/src/jimm/models/clip/sharding.py +57 -0
  43. jax_image_models-0.3.3/src/jimm/models/siglip/__init__.py +3 -0
  44. jax_image_models-0.3.3/src/jimm/models/siglip/params.py +608 -0
  45. jax_image_models-0.3.3/src/jimm/models/siglip/sharding.py +56 -0
  46. jax_image_models-0.3.3/src/jimm/models/siglip/siglip_model.py +614 -0
  47. jax_image_models-0.3.3/src/jimm/models/vit/__init__.py +3 -0
  48. jax_image_models-0.3.3/src/jimm/models/vit/params.py +272 -0
  49. jax_image_models-0.3.3/src/jimm/models/vit/sharding.py +55 -0
  50. jax_image_models-0.3.3/src/jimm/models/vit/vit_model.py +209 -0
  51. jax_image_models-0.3.3/tests/benchmark_utils.py +73 -0
  52. jax_image_models-0.3.3/tests/conftest.py +11 -0
  53. jax_image_models-0.3.3/tests/test_checkpointing.py +135 -0
  54. jax_image_models-0.3.3/tests/test_clip.py +387 -0
  55. jax_image_models-0.3.3/tests/test_loading_utils.py +92 -0
  56. jax_image_models-0.3.3/tests/test_siglip.py +335 -0
  57. jax_image_models-0.3.3/tests/test_transformer.py +60 -0
  58. jax_image_models-0.3.3/tests/test_vit.py +270 -0
  59. jax_image_models-0.3.3/uv.lock +4760 -0
@@ -0,0 +1,2 @@
1
+ # SCM syntax highlighting & preventing 3-way merges
2
+ pixi.lock merge=binary linguist-language=YAML linguist-generated=true
@@ -0,0 +1,32 @@
1
+ name: ci
2
+ on:
3
+ push:
4
+ branches:
5
+ - master
6
+ - main
7
+ permissions:
8
+ contents: write
9
+ jobs:
10
+ deploy:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v4
14
+ - name: Configure Git Credentials
15
+ run: |
16
+ git config user.name github-actions[bot]
17
+ git config user.email 41898282+github-actions[bot]@users.noreply.github.com
18
+ - uses: actions/setup-python@v5
19
+ with:
20
+ python-version: 3.x
21
+ - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
22
+ - uses: actions/cache@v4
23
+ with:
24
+ key: mkdocs-material-${{ env.cache_id }}
25
+ path: .cache
26
+ restore-keys: |
27
+ mkdocs-material-
28
+ - name: Install uv
29
+ uses: astral-sh/setup-uv@v5
30
+ - name: Install the project
31
+ run: uv sync --locked --all-extras --dev && uv pip install -e .
32
+ - run: uv run mkdocs gh-deploy --force
@@ -0,0 +1,57 @@
1
+ name: Claude Code Review
2
+
3
+ on:
4
+ pull_request:
5
+ types: [opened, synchronize]
6
+ # Optional: Only run on specific file changes
7
+ # paths:
8
+ # - "src/**/*.ts"
9
+ # - "src/**/*.tsx"
10
+ # - "src/**/*.js"
11
+ # - "src/**/*.jsx"
12
+
13
+ jobs:
14
+ claude-review:
15
+ # Optional: Filter by PR author
16
+ # if: |
17
+ # github.event.pull_request.user.login == 'external-contributor' ||
18
+ # github.event.pull_request.user.login == 'new-developer' ||
19
+ # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
20
+
21
+ runs-on: ubuntu-latest
22
+ permissions:
23
+ contents: read
24
+ pull-requests: read
25
+ issues: read
26
+ id-token: write
27
+
28
+ steps:
29
+ - name: Checkout repository
30
+ uses: actions/checkout@v4
31
+ with:
32
+ fetch-depth: 1
33
+
34
+ - name: Run Claude Code Review
35
+ id: claude-review
36
+ uses: anthropics/claude-code-action@v1
37
+ with:
38
+ claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
39
+ prompt: |
40
+ REPO: ${{ github.repository }}
41
+ PR NUMBER: ${{ github.event.pull_request.number }}
42
+
43
+ Please review this pull request and provide feedback on:
44
+ - Code quality and best practices
45
+ - Potential bugs or issues
46
+ - Performance considerations
47
+ - Security concerns
48
+ - Test coverage
49
+
50
+ Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.
51
+
52
+ Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR.
53
+
54
+ # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
55
+ # or https://docs.claude.com/en/docs/claude-code/cli-reference for available options
56
+ claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'
57
+
@@ -0,0 +1,50 @@
1
+ name: Claude Code
2
+
3
+ on:
4
+ issue_comment:
5
+ types: [created]
6
+ pull_request_review_comment:
7
+ types: [created]
8
+ issues:
9
+ types: [opened, assigned]
10
+ pull_request_review:
11
+ types: [submitted]
12
+
13
+ jobs:
14
+ claude:
15
+ if: |
16
+ (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
17
+ (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
18
+ (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
19
+ (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
20
+ runs-on: ubuntu-latest
21
+ permissions:
22
+ contents: read
23
+ pull-requests: read
24
+ issues: read
25
+ id-token: write
26
+ actions: read # Required for Claude to read CI results on PRs
27
+ steps:
28
+ - name: Checkout repository
29
+ uses: actions/checkout@v4
30
+ with:
31
+ fetch-depth: 1
32
+
33
+ - name: Run Claude Code
34
+ id: claude
35
+ uses: anthropics/claude-code-action@v1
36
+ with:
37
+ claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
38
+
39
+ # This is an optional setting that allows Claude to read CI results on PRs
40
+ additional_permissions: |
41
+ actions: read
42
+
43
+ # Optional: Give a custom prompt to Claude. If this is not specified, Claude will perform the instructions specified in the comment that tagged it.
44
+ # prompt: 'Update the pull request description to include a summary of changes.'
45
+
46
+ # Optional: Add claude_args to customize behavior and configuration
47
+ # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
48
+ # or https://docs.claude.com/en/docs/claude-code/cli-reference for available options
49
+ # claude_args: '--allowed-tools Bash(gh pr:*)'
50
+
@@ -0,0 +1,26 @@
1
+ name: "Publish release to PyPI"
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - v*
7
+
8
+ jobs:
9
+ run:
10
+ runs-on: ubuntu-latest
11
+ environment:
12
+ name: pypi
13
+ permissions:
14
+ id-token: write
15
+ contents: read
16
+ steps:
17
+ - name: Checkout
18
+ uses: actions/checkout@v4
19
+ - name: Install uv
20
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
21
+ - name: Install Python
22
+ run: uv python install 3.13
23
+ - name: Build
24
+ run: uv build
25
+ - name: Publish
26
+ run: uv publish
@@ -0,0 +1,205 @@
1
+ bin
2
+ tmp
3
+ weights
4
+ .claude
5
+ tests/tokamax_cache
6
+ # pixi environments
7
+ .pixi
8
+ *.egg-info
9
+ .aider*
10
+
11
+ # Byte-compiled / optimized / DLL files
12
+ __pycache__/
13
+ *.py[cod]
14
+ *$py.class
15
+
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ share/python-wheels/
34
+ *.egg-info/
35
+ .installed.cfg
36
+ *.egg
37
+ MANIFEST
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ *.py,cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+ cover/
63
+
64
+ # Translations
65
+ *.mo
66
+ *.pot
67
+
68
+ # Django stuff:
69
+ *.log
70
+ local_settings.py
71
+ db.sqlite3
72
+ db.sqlite3-journal
73
+
74
+ # Flask stuff:
75
+ instance/
76
+ .webassets-cache
77
+
78
+ # Scrapy stuff:
79
+ .scrapy
80
+
81
+ # Sphinx documentation
82
+ docs/_build/
83
+
84
+ # PyBuilder
85
+ .pybuilder/
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ # For a library or package, you might want to ignore these files since the code is
97
+ # intended to run in multiple environments; otherwise, check them in:
98
+ # .python-version
99
+
100
+ # pipenv
101
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
103
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
104
+ # install all needed dependencies.
105
+ #Pipfile.lock
106
+
107
+ # UV
108
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ #uv.lock
112
+
113
+ # poetry
114
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
115
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
116
+ # commonly ignored for libraries.
117
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
118
+ #poetry.lock
119
+ #poetry.toml
120
+
121
+ # pdm
122
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
123
+ #pdm.lock
124
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
125
+ # in version control.
126
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
127
+ .pdm.toml
128
+ .pdm-python
129
+ .pdm-build/
130
+
131
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
132
+ __pypackages__/
133
+
134
+ # Celery stuff
135
+ celerybeat-schedule
136
+ celerybeat.pid
137
+
138
+ # SageMath parsed files
139
+ *.sage.py
140
+
141
+ # Environments
142
+ .env
143
+ .venv
144
+ env/
145
+ venv/
146
+ ENV/
147
+ env.bak/
148
+ venv.bak/
149
+
150
+ # Spyder project settings
151
+ .spyderproject
152
+ .spyproject
153
+
154
+ # Rope project settings
155
+ .ropeproject
156
+
157
+ # mkdocs documentation
158
+ /site
159
+
160
+ # mypy
161
+ .mypy_cache/
162
+ .dmypy.json
163
+ dmypy.json
164
+
165
+ # Pyre type checker
166
+ .pyre/
167
+
168
+ # pytype static type analyzer
169
+ .pytype/
170
+
171
+ # Cython debug symbols
172
+ cython_debug/
173
+
174
+ # PyCharm
175
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
176
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
177
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
178
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
179
+ #.idea/
180
+
181
+ # Abstra
182
+ # Abstra is an AI-powered process automation framework.
183
+ # Ignore directories containing user credentials, local state, and settings.
184
+ # Learn more at https://abstra.io/docs
185
+ .abstra/
186
+
187
+ # Visual Studio Code
188
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
189
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
190
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
191
+ # you could uncomment the following to ignore the entire vscode folder
192
+ # .vscode/
193
+
194
+ # Ruff stuff:
195
+ .ruff_cache/
196
+
197
+ # PyPI configuration file
198
+ .pypirc
199
+
200
+ # Cursor
201
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
202
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
203
+ # refer to https://docs.cursor.com/context/ignore-files
204
+ .cursorignore
205
+ .cursorindexingignore
@@ -0,0 +1,27 @@
1
+ repos:
2
+ # Standard hooks
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v6.0.0
5
+ hooks:
6
+ - id: check-case-conflict
7
+ - id: check-docstring-first
8
+ - id: check-merge-conflict
9
+ - id: check-symlinks
10
+ - id: check-toml
11
+ - id: debug-statements
12
+ - id: mixed-line-ending
13
+ - id: requirements-txt-fixer
14
+ - id: trailing-whitespace
15
+
16
+ - repo: https://github.com/astral-sh/ruff-pre-commit
17
+ # Ruff version.
18
+ rev: v0.14.1
19
+ hooks:
20
+ # Run the linter.
21
+ - id: ruff
22
+ args: [ --fix ]
23
+ # Sort imports
24
+ - id: ruff
25
+ args: ["--select", "I", "--fix"]
26
+ # Run the formatter.
27
+ - id: ruff-format
@@ -0,0 +1 @@
1
+ 3.11
@@ -0,0 +1,3 @@
1
+ {
2
+ "python.languageServer": "None"
3
+ }
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Pinak Paliwal
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,36 @@
1
+ Metadata-Version: 2.4
2
+ Name: jax-image-models
3
+ Version: 0.3.3
4
+ Summary: Jax Image Modeling of Models
5
+ License-File: LICENSE
6
+ Requires-Python: >=3.11
7
+ Requires-Dist: flax>=0.10.6
8
+ Requires-Dist: jax>=0.6.2
9
+ Requires-Dist: jaxtyping>=0.3.2
10
+ Requires-Dist: safetensors>=0.5.3
11
+ Requires-Dist: tokamax>=0.0.12
12
+ Description-Content-Type: text/markdown
13
+
14
+ # Jax Image Modeling of Models (jimm)
15
+ Docs are at: [https://pythoncrazy.github.io/jimm](https://pythoncrazy.github.io/jimm)
16
+ - This aims to be the jax counterpart to timm, with the exception that for image-text models (CLIP, SigLIP, etc), we support the text model entirely.
17
+ - Made with flax nnx, supports weight loading from pytorch_model.bin and safetensors (as well as both methods from huggingface).
18
+
19
+ Models Supported:
20
+ - Vision Transformers
21
+ - Both with a classification linear layer, or not
22
+ - Using a CLS Token for pooling, or using Multihead Attention Pooling
23
+ - Can load any standard variant of Vision Transformers of any size/resolution(e.g. "google/vit-base-patch16-224" or "google/vit-large-patch16-384")
24
+ - CLIP
25
+ - Can load from any checkpoints of the clip model on github (such as "openai/clip-vit-base-patch32" or "geolocal/StreetCLIP")
26
+ - SigLIP
27
+ - Can load any non-naflex version of the SigLIP model, from both siglipv1 and siglipv2 (eg "google/siglip-base-patch16-256" or "google/siglip2-large-patch16-512" from huggingface or locally)
28
+ ## Installation
29
+ ### Using pixi.sh:
30
+ `pixi add jimm@https://github.com/pythoncrazy/jimm.git --pypi`
31
+ ### Using uv
32
+ `uv add --dev git+https://github.com/pythoncrazy/jimm.git`
33
+ or if you prefer to not add as a direct dependency:
34
+ `uv pip install git+https://github.com/pythoncrazy/jimm.git`
35
+ ### Using pip/conda
36
+ `pip install git+https://github.com/pythoncrazy/jimm.git`
@@ -0,0 +1,23 @@
1
+ # Jax Image Modeling of Models (jimm)
2
+ Docs are at: [https://pythoncrazy.github.io/jimm](https://pythoncrazy.github.io/jimm)
3
+ - This aims to be the JAX counterpart to timm, except that for image-text models (CLIP, SigLIP, etc.) the text model is fully supported as well.
4
+ - Built with Flax NNX; supports loading weights from `pytorch_model.bin` and safetensors files, either locally or from HuggingFace.
5
+
6
+ Models Supported:
7
+ - Vision Transformers
8
+ - Both with a classification linear layer, or not
9
+ - Using a CLS Token for pooling, or using Multihead Attention Pooling
10
+ - Can load any standard Vision Transformer variant of any size/resolution (e.g. "google/vit-base-patch16-224" or "google/vit-large-patch16-384")
11
+ - CLIP
12
+ - Can load any CLIP checkpoint from HuggingFace (such as "openai/clip-vit-base-patch32" or "geolocal/StreetCLIP")
13
+ - SigLIP
14
+ - Can load any non-NaFlex version of the SigLIP model, from both SigLIP v1 and SigLIP v2 (e.g. "google/siglip-base-patch16-256" or "google/siglip2-large-patch16-512"), either from HuggingFace or locally
15
+ ## Installation
16
+ ### Using pixi.sh:
17
+ `pixi add jimm@https://github.com/pythoncrazy/jimm.git --pypi`
18
+ ### Using uv
19
+ `uv add git+https://github.com/pythoncrazy/jimm.git`
20
+ or if you prefer to not add as a direct dependency:
21
+ `uv pip install git+https://github.com/pythoncrazy/jimm.git`
22
+ ### Using pip/conda
23
+ `pip install git+https://github.com/pythoncrazy/jimm.git`
@@ -0,0 +1,74 @@
1
+ # jimm docs
2
+
3
+ jimm is a JAX image-model library built on Flax NNX. It supports loading pretrained weights from HuggingFace, hardware-accelerated attention, and FSDP-style explicit sharding.
4
+
5
+ ## Models Implemented
6
+
7
+ * Vision Transformers (CLS pooling or Multihead Attention Pooling)
8
+ * CLIP
9
+ * SigLIP
10
+ * More models to come — contributions are welcome, it's open source!
11
+
12
+ ## Flash / Splash Attention (via Tokamax)
13
+
14
+ All models accept an `attention_fn` argument for hardware-accelerated attention using [Tokamax](https://github.com/openxla/tokamax):
15
+
16
+ | Backend | Hardware | Notes |
17
+ |---------|----------|-------|
18
+ | `"mosaic"` | NVIDIA H100 (SM90) / B100 (SM100) | Pallas Mosaic GPU kernel |
19
+ | `"triton"` | Any NVIDIA GPU | Pallas Triton kernel |
20
+ | `"cudnn"` | NVIDIA GPU | Via JAX-NN / cuDNN |
21
+ | `"mosaic_tpu"` | TPU v5 / v7 | Splash attention (block-sparse) |
22
+ | `"xla_chunked"` | GPU / TPU | Flash-style chunked XLA |
23
+ | `"xla"` | Any | Standard XLA fallback |
24
+
25
+ Pass a list for automatic fallback:
26
+
27
+ ```python
28
+ import jimm
29
+
30
+ # GPU: try H100 Mosaic kernel, fall back to Triton, then XLA
31
+ model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
32
+ attention_fn=jimm.make_tokamax_attention(["mosaic", "triton", "xla"]))
33
+
34
+ # TPU: try Splash attention, fall back to chunked XLA
35
+ model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
36
+ attention_fn=jimm.make_tokamax_attention(["mosaic_tpu", "xla_chunked"]))
37
+ ```
38
+
39
+ > **Note:** Flash/Splash attention does not provide a speedup at typical vision/text context lengths (e.g. 256 image tokens, 77 text tokens). The primary benefit is memory reduction at longer context lengths.
40
+
41
+ ## FSDP / Explicit Sharding
42
+
43
+ All models support JAX explicit sharding (FSDP-style) out of the box. Set up a mesh before model creation:
44
+
45
+ ```python
46
+ from jax.experimental import mesh_utils
47
+ from jax.sharding import AxisType, Mesh
48
+ import jax
49
+
50
+ n_devices = jax.device_count()
51
+ mesh = Mesh(
52
+ mesh_utils.create_device_mesh((1, n_devices)),
53
+ ("data", "fsdp"),
54
+ axis_types=(AxisType.Explicit, AxisType.Explicit),
55
+ )
56
+ jax.set_mesh(mesh)
57
+
58
+ model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14")
59
+ # params are automatically sharded across fsdp axis
60
+ ```
61
+
62
+ Each model ships with a default sharding config (`CLIPSharding`, `SigLIPSharding`, `ViTSharding`) that shards large weight matrices on the contracting (`in_features`) dimension, keeping activations batch-sharded only. Specs are per-layer shapes; the Transformer stack patches Variable metadata after `nnx.vmap` so the optimizer (e.g. AdamW via `nnx.Optimizer`) initialises its state with the correct stacked spec — no manual fixups needed.
63
+
64
+ Pass `sharding=jimm.common.sharding.NoSharding()` to disable all sharding.
65
+
66
+ ## Installation
67
+ ### Using pixi.sh:
68
+ `pixi add jimm@https://github.com/pythoncrazy/jimm.git --pypi`
69
+ ### Using uv
70
+ `uv add git+https://github.com/pythoncrazy/jimm.git`
71
+ or if you prefer to not add as a direct dependency:
72
+ `uv pip install git+https://github.com/pythoncrazy/jimm.git`
73
+ ### Using pip/conda
74
+ `pip install git+https://github.com/pythoncrazy/jimm.git`
@@ -0,0 +1,82 @@
1
+ # CLIP (Contrastive Language–Image Pre-training)
2
+
3
+ CLIP (Contrastive Language–Image Pre-training) is a neural network architecture that learns visual concepts from natural language supervision. It is trained on a large dataset of image-text pairs to create a unified vision-language model that can understand both images and text in a shared semantic space.
4
+
5
+ CLIP consists of two main components:
6
+ 1. A vision encoder (Vision Transformer) that processes images into visual features
7
+ 2. A text encoder (Transformer) that processes text into textual features
8
+
9
+ The model is trained using contrastive learning, where it learns to maximize the cosine similarity between the embeddings of matching image-text pairs while minimizing it for non-matching pairs. This allows CLIP to perform zero-shot classification by comparing image embeddings with text embeddings of potential labels.
10
+
11
+ CLIP was introduced in the paper ["Learning Transferable Visual Models From Natural Language Supervision"](https://arxiv.org/abs/2103.00020) and has shown remarkable zero-shot generalization capabilities across a wide range of visual classification tasks. The CLIP model combines a Vision Transformer and a Text Transformer to learn joint representations of images and text. It is trained to maximize the similarity between matching image-text pairs while minimizing similarity between non-matching pairs.
12
+
13
+ ## Flash / Splash Attention
14
+
15
+ CLIP supports hardware-accelerated attention via [Tokamax](https://github.com/openxla/tokamax). Pass an `attention_fn` at construction time:
16
+
17
+ | Backend | Hardware | Notes |
18
+ |---------|----------|-------|
19
+ | `"mosaic"` | NVIDIA H100 (SM90) / B100 (SM100) | Pallas Mosaic GPU kernel |
20
+ | `"triton"` | Any NVIDIA GPU | Pallas Triton kernel |
21
+ | `"cudnn"` | NVIDIA GPU | Via JAX-NN / cuDNN |
22
+ | `"mosaic_tpu"` | TPU v5 / v7 | Splash attention (block-sparse) |
23
+ | `"xla_chunked"` | GPU / TPU | Flash-style chunked XLA |
24
+ | `"xla"` | Any | Standard XLA fallback |
25
+
26
+ ```python
27
+ import jimm
28
+
29
+ # GPU: try H100 Mosaic kernel, fall back to Triton, then XLA
30
+ model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
31
+ attention_fn=jimm.make_tokamax_attention(["mosaic", "triton", "xla"]))
32
+
33
+ # TPU: try Splash attention, fall back to chunked XLA
34
+ model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14",
35
+ attention_fn=jimm.make_tokamax_attention(["mosaic_tpu", "xla_chunked"]))
36
+ ```
37
+
38
+ You can also apply different kernels to each encoder via `vision_attention_fn` and `text_attention_fn`.
39
+
40
+ > **Note:** Flash/Splash attention does not provide a speedup at typical CLIP context lengths (e.g. 256 image tokens, 77 text tokens). The primary benefit is memory reduction at longer sequence lengths.
41
+
42
+ ## FSDP / Explicit Sharding
43
+
44
+ CLIP supports JAX explicit sharding (FSDP-style) out of the box via `CLIPSharding`. Large weight matrices are sharded on the contracting (`in_features`) dimension so that activations carry only the batch-axis sharding, avoiding duplicate-axis conflicts.
45
+
46
+ ```python
47
+ from jax.experimental import mesh_utils
48
+ from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P
49
+ import jax
50
+
51
+ n_devices = jax.device_count()
52
+ mesh = Mesh(
53
+ mesh_utils.create_device_mesh((1, n_devices)),
54
+ ("data", "fsdp"),
55
+ axis_types=(AxisType.Explicit, AxisType.Explicit),
56
+ )
57
+ jax.set_mesh(mesh)
58
+
59
+ model = jimm.CLIP.from_pretrained("openai/clip-vit-large-patch14")
60
+ # model params are automatically sharded across fsdp axis
61
+ ```
62
+
63
+ `CLIPSharding` specs represent **per-layer** shapes. The `Transformer` stack prepends `None` for the scan axis to the Variable metadata after `nnx.vmap`, so the optimizer (e.g. `nnx.Optimizer` with AdamW) receives the correct stacked spec and initialises its state without any manual fixups.
64
+
65
+ To disable sharding, pass `sharding=jimm.common.sharding.NoSharding()`.
66
+
67
+ ::: jimm.models.clip.CLIPVisionModel
68
+ options:
69
+ show_root_heading: true
70
+ show_source: true
71
+
72
+
73
+ ::: jimm.models.clip.CLIPTextModel
74
+ options:
75
+ show_root_heading: true
76
+ show_source: true
77
+
78
+
79
+ ::: jimm.models.clip.CLIP
80
+ options:
81
+ show_root_heading: true
82
+ show_source: true