logits-sdk 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- logits_sdk-0.1.0/.github/workflows/ci.yml +53 -0
- logits_sdk-0.1.0/.github/workflows/publish.yml +87 -0
- logits_sdk-0.1.0/.gitignore +15 -0
- logits_sdk-0.1.0/LICENSE +177 -0
- logits_sdk-0.1.0/PKG-INFO +116 -0
- logits_sdk-0.1.0/README.md +92 -0
- logits_sdk-0.1.0/pyproject.toml +48 -0
- logits_sdk-0.1.0/src/logits/__init__.py +53 -0
- logits_sdk-0.1.0/src/logits/_config.py +165 -0
- logits_sdk-0.1.0/src/logits/_path_compat.py +27 -0
- logits_sdk-0.1.0/src/logits/_rest_client.py +83 -0
- logits_sdk-0.1.0/src/logits/_service_client.py +215 -0
- logits_sdk-0.1.0/src/logits/_training_client.py +75 -0
- logits_sdk-0.1.0/src/logits/_version.py +2 -0
- logits_sdk-0.1.0/src/logits/py.typed +1 -0
- logits_sdk-0.1.0/src/logits/types/__init__.py +27 -0
- logits_sdk-0.1.0/tests/test_service_client.py +79 -0
- logits_sdk-0.1.0/tests/test_types.py +11 -0
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
pull_request:
|
|
6
|
+
workflow_dispatch:
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
test:
|
|
10
|
+
name: Test Python ${{ matrix.python-version }}
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
strategy:
|
|
13
|
+
fail-fast: false
|
|
14
|
+
matrix:
|
|
15
|
+
python-version: ["3.11", "3.12", "3.13"]
|
|
16
|
+
|
|
17
|
+
steps:
|
|
18
|
+
- name: Check out repository
|
|
19
|
+
uses: actions/checkout@v4
|
|
20
|
+
|
|
21
|
+
- name: Set up Python
|
|
22
|
+
uses: actions/setup-python@v5
|
|
23
|
+
with:
|
|
24
|
+
python-version: ${{ matrix.python-version }}
|
|
25
|
+
|
|
26
|
+
- name: Install package and test dependencies
|
|
27
|
+
run: |
|
|
28
|
+
python -m pip install --upgrade pip
|
|
29
|
+
python -m pip install -e .
|
|
30
|
+
python -m pip install pytest pytest-timeout respx
|
|
31
|
+
|
|
32
|
+
- name: Run tests
|
|
33
|
+
run: pytest
|
|
34
|
+
|
|
35
|
+
package:
|
|
36
|
+
name: Build package
|
|
37
|
+
runs-on: ubuntu-latest
|
|
38
|
+
|
|
39
|
+
steps:
|
|
40
|
+
- name: Check out repository
|
|
41
|
+
uses: actions/checkout@v4
|
|
42
|
+
|
|
43
|
+
- name: Set up Python
|
|
44
|
+
uses: actions/setup-python@v5
|
|
45
|
+
with:
|
|
46
|
+
python-version: "3.12"
|
|
47
|
+
|
|
48
|
+
- name: Build and check distributions
|
|
49
|
+
run: |
|
|
50
|
+
python -m pip install --upgrade pip
|
|
51
|
+
python -m pip install build twine
|
|
52
|
+
python -m build
|
|
53
|
+
python -m twine check dist/*
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
name: Publish
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
workflow_dispatch:
|
|
7
|
+
inputs:
|
|
8
|
+
repository:
|
|
9
|
+
description: "Repository to publish to"
|
|
10
|
+
required: true
|
|
11
|
+
default: "testpypi"
|
|
12
|
+
type: choice
|
|
13
|
+
options:
|
|
14
|
+
- testpypi
|
|
15
|
+
- pypi
|
|
16
|
+
|
|
17
|
+
permissions:
|
|
18
|
+
contents: read
|
|
19
|
+
|
|
20
|
+
jobs:
|
|
21
|
+
build:
|
|
22
|
+
name: Build distributions
|
|
23
|
+
runs-on: ubuntu-latest
|
|
24
|
+
|
|
25
|
+
steps:
|
|
26
|
+
- name: Check out repository
|
|
27
|
+
uses: actions/checkout@v4
|
|
28
|
+
|
|
29
|
+
- name: Set up Python
|
|
30
|
+
uses: actions/setup-python@v5
|
|
31
|
+
with:
|
|
32
|
+
python-version: "3.12"
|
|
33
|
+
|
|
34
|
+
- name: Build package
|
|
35
|
+
run: |
|
|
36
|
+
python -m pip install --upgrade pip
|
|
37
|
+
python -m pip install build twine
|
|
38
|
+
python -m build
|
|
39
|
+
python -m twine check dist/*
|
|
40
|
+
|
|
41
|
+
- name: Upload distributions
|
|
42
|
+
uses: actions/upload-artifact@v4
|
|
43
|
+
with:
|
|
44
|
+
name: python-package-distributions
|
|
45
|
+
path: dist/
|
|
46
|
+
|
|
47
|
+
publish-testpypi:
|
|
48
|
+
name: Publish to TestPyPI
|
|
49
|
+
needs: build
|
|
50
|
+
if: github.event_name == 'workflow_dispatch' && inputs.repository == 'testpypi'
|
|
51
|
+
runs-on: ubuntu-latest
|
|
52
|
+
environment: testpypi
|
|
53
|
+
permissions:
|
|
54
|
+
id-token: write
|
|
55
|
+
contents: read
|
|
56
|
+
|
|
57
|
+
steps:
|
|
58
|
+
- name: Download distributions
|
|
59
|
+
uses: actions/download-artifact@v4
|
|
60
|
+
with:
|
|
61
|
+
name: python-package-distributions
|
|
62
|
+
path: dist/
|
|
63
|
+
|
|
64
|
+
- name: Publish distributions to TestPyPI
|
|
65
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
66
|
+
with:
|
|
67
|
+
repository-url: https://test.pypi.org/legacy/
|
|
68
|
+
|
|
69
|
+
publish-pypi:
|
|
70
|
+
name: Publish to PyPI
|
|
71
|
+
needs: build
|
|
72
|
+
if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && inputs.repository == 'pypi')
|
|
73
|
+
runs-on: ubuntu-latest
|
|
74
|
+
environment: pypi
|
|
75
|
+
permissions:
|
|
76
|
+
id-token: write
|
|
77
|
+
contents: read
|
|
78
|
+
|
|
79
|
+
steps:
|
|
80
|
+
- name: Download distributions
|
|
81
|
+
uses: actions/download-artifact@v4
|
|
82
|
+
with:
|
|
83
|
+
name: python-package-distributions
|
|
84
|
+
path: dist/
|
|
85
|
+
|
|
86
|
+
- name: Publish distributions to PyPI
|
|
87
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
logits_sdk-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
Apache License
|
|
2
|
+
Version 2.0, January 2004
|
|
3
|
+
http://www.apache.org/licenses/
|
|
4
|
+
|
|
5
|
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
6
|
+
|
|
7
|
+
1. Definitions.
|
|
8
|
+
|
|
9
|
+
"License" shall mean the terms and conditions for use, reproduction, and
|
|
10
|
+
distribution as defined by Sections 1 through 9 of this document.
|
|
11
|
+
|
|
12
|
+
"Licensor" shall mean the copyright owner or entity authorized by the
|
|
13
|
+
copyright owner that is granting the License.
|
|
14
|
+
|
|
15
|
+
"Legal Entity" shall mean the union of the acting entity and all other
|
|
16
|
+
entities that control, are controlled by, or are under common control with
|
|
17
|
+
that entity. For the purposes of this definition, "control" means (i) the
|
|
18
|
+
power, direct or indirect, to cause the direction or management of such
|
|
19
|
+
entity, whether by contract or otherwise, or (ii) ownership of fifty percent
|
|
20
|
+
(50%) or more of the outstanding shares, or (iii) beneficial ownership of
|
|
21
|
+
such entity.
|
|
22
|
+
|
|
23
|
+
"You" (or "Your") shall mean an individual or Legal Entity exercising
|
|
24
|
+
permissions granted by this License.
|
|
25
|
+
|
|
26
|
+
"Source" form shall mean the preferred form for making modifications,
|
|
27
|
+
including but not limited to software source code, documentation source, and
|
|
28
|
+
configuration files.
|
|
29
|
+
|
|
30
|
+
"Object" form shall mean any form resulting from mechanical transformation
|
|
31
|
+
or translation of a Source form, including but not limited to compiled object
|
|
32
|
+
code, generated documentation, and conversions to other media types.
|
|
33
|
+
|
|
34
|
+
"Work" shall mean the work of authorship, whether in Source or Object form,
|
|
35
|
+
made available under the License, as indicated by a copyright notice that is
|
|
36
|
+
included in or attached to the work (an example is provided in the Appendix
|
|
37
|
+
below).
|
|
38
|
+
|
|
39
|
+
"Derivative Works" shall mean any work, whether in Source or Object form,
|
|
40
|
+
that is based on (or derived from) the Work and for which the editorial
|
|
41
|
+
revisions, annotations, elaborations, or other modifications represent, as a
|
|
42
|
+
whole, an original work of authorship. For the purposes of this License,
|
|
43
|
+
Derivative Works shall not include works that remain separable from, or
|
|
44
|
+
merely link (or bind by name) to the interfaces of, the Work and Derivative
|
|
45
|
+
Works thereof.
|
|
46
|
+
|
|
47
|
+
"Contribution" shall mean any work of authorship, including the original
|
|
48
|
+
version of the Work and any modifications or additions to that Work or
|
|
49
|
+
Derivative Works thereof, that is intentionally submitted to Licensor for
|
|
50
|
+
inclusion in the Work by the copyright owner or by an individual or Legal
|
|
51
|
+
Entity authorized to submit on behalf of the copyright owner. For the
|
|
52
|
+
purposes of this definition, "submitted" means any form of electronic,
|
|
53
|
+
verbal, or written communication sent to the Licensor or its representatives,
|
|
54
|
+
including but not limited to communication on electronic mailing lists, source
|
|
55
|
+
code control systems, and issue tracking systems that are managed by, or on
|
|
56
|
+
behalf of, the Licensor for the purpose of discussing and improving the Work,
|
|
57
|
+
but excluding communication that is conspicuously marked or otherwise
|
|
58
|
+
designated in writing by the copyright owner as "Not a Contribution."
|
|
59
|
+
|
|
60
|
+
"Contributor" shall mean Licensor and any individual or Legal Entity on
|
|
61
|
+
behalf of whom a Contribution has been received by Licensor and subsequently
|
|
62
|
+
incorporated within the Work.
|
|
63
|
+
|
|
64
|
+
2. Grant of Copyright License. Subject to the terms and conditions of this
|
|
65
|
+
License, each Contributor hereby grants to You a perpetual, worldwide,
|
|
66
|
+
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
|
|
67
|
+
reproduce, prepare Derivative Works of, publicly display, publicly perform,
|
|
68
|
+
sublicense, and distribute the Work and such Derivative Works in Source or
|
|
69
|
+
Object form.
|
|
70
|
+
|
|
71
|
+
3. Grant of Patent License. Subject to the terms and conditions of this
|
|
72
|
+
License, each Contributor hereby grants to You a perpetual, worldwide,
|
|
73
|
+
non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this
|
|
74
|
+
section) patent license to make, have made, use, offer to sell, sell, import,
|
|
75
|
+
and otherwise transfer the Work, where such license applies only to those
|
|
76
|
+
patent claims licensable by such Contributor that are necessarily infringed
|
|
77
|
+
by their Contribution(s) alone or by combination of their Contribution(s) with
|
|
78
|
+
the Work to which such Contribution(s) was submitted. If You institute patent
|
|
79
|
+
litigation against any entity (including a cross-claim or counterclaim in a
|
|
80
|
+
lawsuit) alleging that the Work or a Contribution incorporated within the
|
|
81
|
+
Work constitutes direct or contributory patent infringement, then any patent
|
|
82
|
+
licenses granted to You under this License for that Work shall terminate as
|
|
83
|
+
of the date such litigation is filed.
|
|
84
|
+
|
|
85
|
+
4. Redistribution. You may reproduce and distribute copies of the Work or
|
|
86
|
+
Derivative Works thereof in any medium, with or without modifications, and in
|
|
87
|
+
Source or Object form, provided that You meet the following conditions:
|
|
88
|
+
|
|
89
|
+
(a) You must give any other recipients of the Work or Derivative Works a copy
|
|
90
|
+
of this License; and
|
|
91
|
+
|
|
92
|
+
(b) You must cause any modified files to carry prominent notices stating that
|
|
93
|
+
You changed the files; and
|
|
94
|
+
|
|
95
|
+
(c) You must retain, in the Source form of any Derivative Works that You
|
|
96
|
+
distribute, all copyright, patent, trademark, and attribution notices from
|
|
97
|
+
the Source form of the Work, excluding those notices that do not pertain to
|
|
98
|
+
any part of the Derivative Works; and
|
|
99
|
+
|
|
100
|
+
(d) If the Work includes a "NOTICE" text file as part of its distribution,
|
|
101
|
+
then any Derivative Works that You distribute must include a readable copy of
|
|
102
|
+
the attribution notices contained within such NOTICE file, excluding those
|
|
103
|
+
notices that do not pertain to any part of the Derivative Works, in at least
|
|
104
|
+
one of the following places: within a NOTICE text file distributed as part of
|
|
105
|
+
the Derivative Works; within the Source form or documentation, if provided
|
|
106
|
+
along with the Derivative Works; or, within a display generated by the
|
|
107
|
+
Derivative Works, if and wherever such third-party notices normally appear.
|
|
108
|
+
The contents of the NOTICE file are for informational purposes only and do
|
|
109
|
+
not modify the License. You may add Your own attribution notices within
|
|
110
|
+
Derivative Works that You distribute, alongside or as an addendum to the
|
|
111
|
+
NOTICE text from the Work, provided that such additional attribution notices
|
|
112
|
+
cannot be construed as modifying the License.
|
|
113
|
+
|
|
114
|
+
You may add Your own copyright statement to Your modifications and may
|
|
115
|
+
provide additional or different license terms and conditions for use,
|
|
116
|
+
reproduction, or distribution of Your modifications, or for any such
|
|
117
|
+
Derivative Works as a whole, provided Your use, reproduction, and
|
|
118
|
+
distribution of the Work otherwise complies with the conditions stated in
|
|
119
|
+
this License.
|
|
120
|
+
|
|
121
|
+
5. Submission of Contributions. Unless You explicitly state otherwise, any
|
|
122
|
+
Contribution intentionally submitted for inclusion in the Work by You to the
|
|
123
|
+
Licensor shall be under the terms and conditions of this License, without any
|
|
124
|
+
additional terms or conditions. Notwithstanding the above, nothing herein
|
|
125
|
+
shall supersede or modify the terms of any separate license agreement you may
|
|
126
|
+
have executed with Licensor regarding such Contributions.
|
|
127
|
+
|
|
128
|
+
6. Trademarks. This License does not grant permission to use the trade names,
|
|
129
|
+
trademarks, service marks, or product names of the Licensor, except as
|
|
130
|
+
required for reasonable and customary use in describing the origin of the Work
|
|
131
|
+
and reproducing the content of the NOTICE file.
|
|
132
|
+
|
|
133
|
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
|
|
134
|
+
writing, Licensor provides the Work (and each Contributor provides its
|
|
135
|
+
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
136
|
+
KIND, either express or implied, including, without limitation, any warranties
|
|
137
|
+
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
138
|
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
139
|
+
appropriateness of using or redistributing the Work and assume any risks
|
|
140
|
+
associated with Your exercise of permissions under this License.
|
|
141
|
+
|
|
142
|
+
8. Limitation of Liability. In no event and under no legal theory, whether in
|
|
143
|
+
tort (including negligence), contract, or otherwise, unless required by
|
|
144
|
+
applicable law (such as deliberate and grossly negligent acts) or agreed to
|
|
145
|
+
in writing, shall any Contributor be liable to You for damages, including any
|
|
146
|
+
direct, indirect, special, incidental, or consequential damages of any
|
|
147
|
+
character arising as a result of this License or out of the use or inability
|
|
148
|
+
to use the Work (including but not limited to damages for loss of goodwill,
|
|
149
|
+
work stoppage, computer failure or malfunction, or any and all other
|
|
150
|
+
commercial damages or losses), even if such Contributor has been advised of
|
|
151
|
+
the possibility of such damages.
|
|
152
|
+
|
|
153
|
+
9. Accepting Warranty or Additional Liability. While redistributing the Work
|
|
154
|
+
or Derivative Works thereof, You may choose to offer, and charge a fee for,
|
|
155
|
+
acceptance of support, warranty, indemnity, or other liability obligations
|
|
156
|
+
and/or rights consistent with this License. However, in accepting such
|
|
157
|
+
obligations, You may act only on Your own behalf and on Your sole
|
|
158
|
+
responsibility, not on behalf of any other Contributor, and only if You agree
|
|
159
|
+
to indemnify, defend, and hold each Contributor harmless for any liability
|
|
160
|
+
incurred by, or claims asserted against, such Contributor by reason of your
|
|
161
|
+
accepting any such warranty or additional liability.
|
|
162
|
+
|
|
163
|
+
END OF TERMS AND CONDITIONS
|
|
164
|
+
|
|
165
|
+
Copyright 2026 Gradient
|
|
166
|
+
|
|
167
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
168
|
+
you may not use this file except in compliance with the License.
|
|
169
|
+
You may obtain a copy of the License at
|
|
170
|
+
|
|
171
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
172
|
+
|
|
173
|
+
Unless required by applicable law or agreed to in writing, software
|
|
174
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
175
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
176
|
+
See the License for the specific language governing permissions and
|
|
177
|
+
limitations under the License.
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: logits-sdk
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: The Logits Python SDK facade over the official Tinker SDK (import as `logits`)
|
|
5
|
+
Project-URL: Homepage, https://github.com/GradientHQ/logits
|
|
6
|
+
Project-URL: Repository, https://github.com/GradientHQ/logits
|
|
7
|
+
Project-URL: Issues, https://github.com/GradientHQ/logits/issues
|
|
8
|
+
Project-URL: Documentation, https://github.com/GradientHQ/logits#readme
|
|
9
|
+
Author-email: Gradient <support@gradient.ai>
|
|
10
|
+
License-Expression: Apache-2.0
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Keywords: logits,logits-sdk,sampling,tinker,training
|
|
13
|
+
Classifier: Development Status :: 3 - Alpha
|
|
14
|
+
Classifier: Intended Audience :: Developers
|
|
15
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Typing :: Typed
|
|
21
|
+
Requires-Python: >=3.11
|
|
22
|
+
Requires-Dist: tinker<0.22,>=0.21
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
# Logits Python SDK
|
|
26
|
+
|
|
27
|
+
`logits-sdk` is the official Python SDK for the Logits platform. Install it
|
|
28
|
+
under the distribution name `logits-sdk`; application code imports it as
|
|
29
|
+
`logits`.
|
|
30
|
+
|
|
31
|
+
## Installation
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
pip install logits-sdk
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Authentication
|
|
38
|
+
|
|
39
|
+
Set a Logits API key before creating clients:
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
export LOGITS_API_KEY="your-api-key"
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
For non-default deployments, set a base URL:
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
export LOGITS_BASE_URL="https://api.example.com"
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
## Usage
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
import logits
|
|
55
|
+
|
|
56
|
+
service_client = logits.ServiceClient()
|
|
57
|
+
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
|
|
58
|
+
future = sampling_client.sample(
|
|
59
|
+
prompt=logits.ModelInput.from_ints([1, 2, 3]),
|
|
60
|
+
num_samples=1,
|
|
61
|
+
sampling_params=logits.SamplingParams(max_tokens=32),
|
|
62
|
+
)
|
|
63
|
+
result = future.result()
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
Async usage:
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
import logits
|
|
70
|
+
|
|
71
|
+
service_client = logits.ServiceClient()
|
|
72
|
+
sampling_client = await service_client.create_sampling_client_async(
|
|
73
|
+
base_model="Qwen/Qwen3-8B"
|
|
74
|
+
)
|
|
75
|
+
result = await sampling_client.sample_async(
|
|
76
|
+
prompt=logits.ModelInput.from_ints([1, 2, 3]),
|
|
77
|
+
num_samples=1,
|
|
78
|
+
sampling_params=logits.SamplingParams(max_tokens=32),
|
|
79
|
+
)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
Close the underlying holder when a long-running process no longer needs the client:
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
service_client.holder.close()
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
## Development
|
|
89
|
+
|
|
90
|
+
Install the package and test dependencies:
|
|
91
|
+
|
|
92
|
+
```bash
|
|
93
|
+
python -m pip install -e .
|
|
94
|
+
python -m pip install pytest pytest-timeout respx
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
Run the test suite:
|
|
98
|
+
|
|
99
|
+
```bash
|
|
100
|
+
pytest
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
Build and validate release artifacts:
|
|
104
|
+
|
|
105
|
+
```bash
|
|
106
|
+
python -m pip install build twine
|
|
107
|
+
python -m build
|
|
108
|
+
python -m twine check dist/*
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
## Release Checklist
|
|
112
|
+
|
|
113
|
+
- Bump both `pyproject.toml` and `src/logits/_version.py`.
|
|
114
|
+
- Run `pytest`, `python -m build`, and `python -m twine check dist/*`.
|
|
115
|
+
- Confirm GitHub Actions CI is green on the release commit.
|
|
116
|
+
- Create a GitHub release from the version tag to trigger the publish workflow.
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# Logits Python SDK
|
|
2
|
+
|
|
3
|
+
`logits-sdk` is the official Python SDK for the Logits platform. Install it
|
|
4
|
+
under the distribution name `logits-sdk`; application code imports it as
|
|
5
|
+
`logits`.
|
|
6
|
+
|
|
7
|
+
## Installation
|
|
8
|
+
|
|
9
|
+
```bash
|
|
10
|
+
pip install logits-sdk
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## Authentication
|
|
14
|
+
|
|
15
|
+
Set a Logits API key before creating clients:
|
|
16
|
+
|
|
17
|
+
```bash
|
|
18
|
+
export LOGITS_API_KEY="your-api-key"
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
For non-default deployments, set a base URL:
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
export LOGITS_BASE_URL="https://api.example.com"
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Usage
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
import logits
|
|
31
|
+
|
|
32
|
+
service_client = logits.ServiceClient()
|
|
33
|
+
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
|
|
34
|
+
future = sampling_client.sample(
|
|
35
|
+
prompt=logits.ModelInput.from_ints([1, 2, 3]),
|
|
36
|
+
num_samples=1,
|
|
37
|
+
sampling_params=logits.SamplingParams(max_tokens=32),
|
|
38
|
+
)
|
|
39
|
+
result = future.result()
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Async usage:
|
|
43
|
+
|
|
44
|
+
```python
|
|
45
|
+
import logits
|
|
46
|
+
|
|
47
|
+
service_client = logits.ServiceClient()
|
|
48
|
+
sampling_client = await service_client.create_sampling_client_async(
|
|
49
|
+
base_model="Qwen/Qwen3-8B"
|
|
50
|
+
)
|
|
51
|
+
result = await sampling_client.sample_async(
|
|
52
|
+
prompt=logits.ModelInput.from_ints([1, 2, 3]),
|
|
53
|
+
num_samples=1,
|
|
54
|
+
sampling_params=logits.SamplingParams(max_tokens=32),
|
|
55
|
+
)
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Close the underlying holder when a long-running process no longer needs the client:
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
service_client.holder.close()
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
## Development
|
|
65
|
+
|
|
66
|
+
Install the package and test dependencies:
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
python -m pip install -e .
|
|
70
|
+
python -m pip install pytest pytest-timeout respx
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
Run the test suite:
|
|
74
|
+
|
|
75
|
+
```bash
|
|
76
|
+
pytest
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
Build and validate release artifacts:
|
|
80
|
+
|
|
81
|
+
```bash
|
|
82
|
+
python -m pip install build twine
|
|
83
|
+
python -m build
|
|
84
|
+
python -m twine check dist/*
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
## Release Checklist
|
|
88
|
+
|
|
89
|
+
- Bump both `pyproject.toml` and `src/logits/_version.py`.
|
|
90
|
+
- Run `pytest`, `python -m build`, and `python -m twine check dist/*`.
|
|
91
|
+
- Confirm GitHub Actions CI is green on the release commit.
|
|
92
|
+
- Create a GitHub release from the version tag to trigger the publish workflow.
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "logits-sdk"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "The Logits Python SDK facade over the official Tinker SDK (import as `logits`)"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [{ name = "Gradient", email = "support@gradient.ai" }]
|
|
7
|
+
license = "Apache-2.0"
|
|
8
|
+
requires-python = ">=3.11"
|
|
9
|
+
dependencies = [
|
|
10
|
+
"tinker>=0.21,<0.22",
|
|
11
|
+
]
|
|
12
|
+
keywords = ["logits", "logits-sdk", "tinker", "training", "sampling"]
|
|
13
|
+
classifiers = [
|
|
14
|
+
"Development Status :: 3 - Alpha",
|
|
15
|
+
"Intended Audience :: Developers",
|
|
16
|
+
"License :: OSI Approved :: Apache Software License",
|
|
17
|
+
"Programming Language :: Python :: 3",
|
|
18
|
+
"Programming Language :: Python :: 3.11",
|
|
19
|
+
"Programming Language :: Python :: 3.12",
|
|
20
|
+
"Programming Language :: Python :: 3.13",
|
|
21
|
+
"Typing :: Typed",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
[project.urls]
|
|
25
|
+
Homepage = "https://github.com/GradientHQ/logits"
|
|
26
|
+
Repository = "https://github.com/GradientHQ/logits"
|
|
27
|
+
Issues = "https://github.com/GradientHQ/logits/issues"
|
|
28
|
+
Documentation = "https://github.com/GradientHQ/logits#readme"
|
|
29
|
+
|
|
30
|
+
[dependency-groups]
|
|
31
|
+
dev = [
|
|
32
|
+
"pytest",
|
|
33
|
+
"pytest-timeout",
|
|
34
|
+
"respx",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[build-system]
|
|
38
|
+
requires = ["hatchling"]
|
|
39
|
+
build-backend = "hatchling.build"
|
|
40
|
+
|
|
41
|
+
[tool.hatch.build.targets.wheel]
|
|
42
|
+
packages = ["src/logits"]
|
|
43
|
+
|
|
44
|
+
[tool.pytest.ini_options]
|
|
45
|
+
testpaths = ["tests"]
|
|
46
|
+
python_files = ["test_*.py"]
|
|
47
|
+
timeout = 15
|
|
48
|
+
timeout_method = "thread"
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import typing as _t
|
|
4
|
+
|
|
5
|
+
import tinker as _tinker
|
|
6
|
+
from tinker import * # noqa: F403
|
|
7
|
+
|
|
8
|
+
from . import types
|
|
9
|
+
from ._config import (
|
|
10
|
+
LOGITS_API_KEY_ENV,
|
|
11
|
+
LOGITS_BASE_URL_ENV,
|
|
12
|
+
TINKER_API_KEY_ENV,
|
|
13
|
+
TINKER_BASE_URL_ENV,
|
|
14
|
+
resolve_api_key,
|
|
15
|
+
resolve_base_url,
|
|
16
|
+
)
|
|
17
|
+
from ._service_client import ServiceClient, create_service_client
|
|
18
|
+
from ._training_client import TrainingClient
|
|
19
|
+
from ._version import __title__, __version__
|
|
20
|
+
|
|
21
|
+
# Logits-branded alias for Tinker's base exception. ``TinkerError`` is not
# imported by name in this module; it is expected to arrive through the
# ``from tinker import *`` star import above.
LogitsError = TinkerError

# Identify the backend SDK this facade wraps and the version that is installed.
__backend_sdk__ = "tinker"
__backend_sdk_version__ = _tinker.__version__

# Public API of the package: every name Tinker exports, except that Tinker's
# own "ServiceClient" / "TrainingClient" entries are filtered out and then
# re-listed so the names refer to the classes imported from
# ``._service_client`` and ``._training_client``; the Logits-specific helpers
# and metadata follow.
__all__ = [
    *[name for name in _tinker.__all__ if name not in ("ServiceClient", "TrainingClient")],
    "ServiceClient",
    "TrainingClient",
    "create_service_client",
    "LOGITS_API_KEY_ENV",
    "LOGITS_BASE_URL_ENV",
    "TINKER_API_KEY_ENV",
    "TINKER_BASE_URL_ENV",
    "resolve_api_key",
    "resolve_base_url",
    "LogitsError",
    "__title__",
    "__version__",
    "__backend_sdk__",
    "__backend_sdk_version__",
]
|
|
43
|
+
|
|
44
|
+
# Imported only at runtime; static type checkers skip this branch.
if not _t.TYPE_CHECKING:
    from tinker import resources as resources

# Re-label every exported, non-dunder symbol so its ``__module__`` reads
# "logits" rather than the module it was originally defined in.
__locals = locals()
for __name in __all__:
    if __name.startswith("__"):
        continue
    try:
        __locals[__name].__module__ = "logits"
    except (TypeError, AttributeError):
        # Best effort: some exported objects do not allow ``__module__`` to
        # be reassigned, and those are simply left untouched.
        pass
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import tinker.lib._auth_token_provider as _auth_provider
|
|
9
|
+
import tinker.lib.internal_client_holder as _internal_client_holder
|
|
10
|
+
import tinker._client as _tinker_client
|
|
11
|
+
from tinker import types as _tinker_types
|
|
12
|
+
|
|
13
|
+
# Names of the environment variables consulted for credentials and endpoint.
# Wherever this module reads them, the LOGITS_* name is passed to
# ``_get_first_env`` ahead of the TINKER_* name, so the Logits variable wins
# when both are set.
LOGITS_API_KEY_ENV = "LOGITS_API_KEY"
LOGITS_BASE_URL_ENV = "LOGITS_BASE_URL"
TINKER_API_KEY_ENV = "TINKER_API_KEY"
TINKER_BASE_URL_ENV = "TINKER_BASE_URL"
# Key that ``resolve_api_key`` looks up in ``default_headers``.
API_KEY_HEADER = "X-API-Key"

# Package-wide logger (fixed name "logits", not ``__name__``).
_logger = logging.getLogger("logits")

# Payload validated into ``ClientConfigResponse`` by the fallback installed in
# ``install_client_config_fallback`` when the real client-config call fails
# with an error that ``_is_missing_endpoint_error`` recognises.
# NOTE(review): the meaning of each key is defined by the Tinker SDK, not by
# this file — confirm against the pinned Tinker version before changing any
# value. The two semaphore sizes are 1 << 30 (1073741824); presumably bytes.
DEFAULT_CLIENT_CONFIG: dict[str, Any] = {
    "pjwt_auth_enabled": False,
    "credential_default_source": "api_key",
    "sample_dispatch_bytes_semaphore_size": 1 << 30,
    "inflight_response_bytes_semaphore_size": 1 << 30,
    "parallel_fwdbwd_chunks": 1,
    "proto_write_fwdbwd": False,
    "billing_exception_max_pause_duration_sec": 0,
    "grpc_target": "",
    "enable_grpc_retrieve_future": False,
    "sample_no_retries": False,
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class LogitsApiKeyAuthProvider(_auth_provider.AuthTokenProvider):
    """Auth-token provider for Tinker clients that is backed by a Logits API key.

    The key is resolved once, at construction time, and handed back unchanged
    on every ``get_token`` call.
    """

    def __init__(self, api_key: str | None = None) -> None:
        """Resolve and store the API key.

        Args:
            api_key: Explicit key. When falsy (``None`` or empty), the
                ``LOGITS_API_KEY`` and then ``TINKER_API_KEY`` environment
                variables are consulted instead.

        Raises:
            TinkerError: If no non-empty key can be found.
        """
        token = api_key if api_key else _get_first_env(LOGITS_API_KEY_ENV, TINKER_API_KEY_ENV)
        if token:
            self._token = token
            return
        raise _auth_provider.TinkerError(
            "The api_key client option must be set either by passing api_key to the client"
            f" or by setting the {LOGITS_API_KEY_ENV} environment variable"
        )

    async def get_token(self) -> str | None:
        """Return the key captured by ``__init__``."""
        return self._token
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def install_logits_auth_compat() -> None:
    """Rebind ``ApiKeyAuthProvider`` in the Tinker modules to the Logits provider.

    The attribute is overwritten on the auth-provider module, the
    internal-client-holder module and the client module, in that order, so
    each of those module namespaces resolves ``ApiKeyAuthProvider`` to
    ``LogitsApiKeyAuthProvider`` afterwards.
    """
    patched_modules = (_auth_provider, _internal_client_holder, _tinker_client)
    for module in patched_modules:
        module.ApiKeyAuthProvider = LogitsApiKeyAuthProvider
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
_FETCH_CLIENT_CONFIG_PATCHED = False
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def install_client_config_fallback() -> None:
    """Patch Tinker's holder so a missing client-config endpoint is not fatal.

    Some Logits deployments do not serve `/api/v1/client/config` yet, and the
    upstream Tinker holder blocks on that endpoint before any other call.
    This function replaces ``InternalClientHolder._fetch_client_config`` with
    a wrapper that first delegates to the original coroutine. If that call
    raises an error that ``_is_missing_endpoint_error`` recognises, the
    wrapper logs it at INFO level and returns ``DEFAULT_CLIENT_CONFIG``
    validated as a ``ClientConfigResponse``; any other error propagates
    unchanged. A backend that does serve the endpoint is unaffected, because
    the original result is returned as-is.

    The patch is installed at most once: after the first call sets the
    module-level ``_FETCH_CLIENT_CONFIG_PATCHED`` flag, later calls return
    immediately.
    """
    global _FETCH_CLIENT_CONFIG_PATCHED
    if _FETCH_CLIENT_CONFIG_PATCHED:
        return

    upstream_fetch = _internal_client_holder.InternalClientHolder._fetch_client_config

    async def _fetch_client_config_with_fallback(
        self: _internal_client_holder.InternalClientHolder,
        auth: _auth_provider.AuthTokenProvider,
    ) -> _tinker_types.ClientConfigResponse:
        try:
            return await upstream_fetch(self, auth)
        except Exception as exc:  # noqa: BLE001 — narrowed by the check below
            if _is_missing_endpoint_error(exc):
                _logger.info(
                    "logits: /api/v1/client/config unavailable (%s); using built-in defaults",
                    _summarize_exc(exc),
                )
                return _tinker_types.ClientConfigResponse.model_validate(DEFAULT_CLIENT_CONFIG)
            # Anything else is a genuine failure and must surface to the caller.
            raise

    _internal_client_holder.InternalClientHolder._fetch_client_config = (  # type: ignore[method-assign]
        _fetch_client_config_with_fallback
    )
    _FETCH_CLIENT_CONFIG_PATCHED = True
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _is_missing_endpoint_error(exc: BaseException) -> bool:
    """Return True when *exc* means the endpoint is absent or unreachable.

    The exception and every exception linked through ``__cause__`` is
    inspected, because client libraries commonly re-raise a transport error
    wrapped in their own exception type. A link matches when any of these
    hold:

    * it carries a ``response`` attribute whose ``status_code`` is 404;
    * its class, or any base class, is named ``NotFoundError`` (404 wrapper)
      or ``ConnectError`` (httpx's connection-failure class, which does not
      derive from ``OSError``);
    * it is an ``OSError`` (covers ``ConnectionError`` and socket/DNS errors).

    Classes are matched by name so this helper needs no import of the HTTP
    client libraries.
    """
    matching_class_names = {"NotFoundError", "ConnectError"}
    visited: set[int] = set()
    current: BaseException | None = exc
    # ``visited`` guards against a cyclic ``__cause__`` chain.
    while current is not None and id(current) not in visited:
        visited.add(id(current))

        response = getattr(current, "response", None)
        if response is not None and getattr(response, "status_code", None) == 404:
            return True

        # Walk the MRO so subclasses of the named classes are recognised too.
        if any(cls.__name__ in matching_class_names for cls in type(current).__mro__):
            return True

        # ConnectionError is a subclass of OSError, so one check covers both.
        if isinstance(current, OSError):
            return True

        current = current.__cause__
    return False
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _summarize_exc(exc: BaseException) -> str:
    """Render *exc* as ``"<ClassName>: <message>"``, cut to 200 characters."""
    class_name = type(exc).__name__
    rendered = "{}: {}".format(class_name, exc)
    return rendered[:200]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _get_first_env(*names: str) -> str | None:
    """Return the first non-empty environment value among *names*, in order.

    Variables that are unset or set to the empty string are skipped. Returns
    ``None`` when none of the names has a non-empty value.
    """
    candidates = (os.environ.get(variable) for variable in names)
    return next((found for found in candidates if found), None)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def resolve_api_key(
    api_key: str | None = None,
    *,
    default_headers: Mapping[str, str] | None = None,
) -> str | None:
    """Pick the API key to use, by precedence.

    1. An explicit *api_key* (anything other than ``None``, including ``""``).
    2. A non-empty ``API_KEY_HEADER`` entry in *default_headers*.
    3. The first non-empty value of the ``LOGITS_API_KEY_ENV`` then
       ``TINKER_API_KEY_ENV`` environment variables.

    Returns ``None`` when no source provides a key.
    """
    if api_key is not None:
        return api_key

    if default_headers is not None:
        if default_headers.get(API_KEY_HEADER):
            return default_headers[API_KEY_HEADER]

    return _get_first_env(LOGITS_API_KEY_ENV, TINKER_API_KEY_ENV)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def resolve_base_url(base_url: str | None = None) -> str | None:
    """Return *base_url* when given, otherwise the environment's base URL.

    The environment lookup tries ``LOGITS_BASE_URL_ENV`` before
    ``TINKER_BASE_URL_ENV`` and yields ``None`` when neither has a non-empty
    value.
    """
    if base_url is None:
        return _get_first_env(LOGITS_BASE_URL_ENV, TINKER_BASE_URL_ENV)
    return base_url
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def resolve_service_client_kwargs(
    *,
    api_key: str | None = None,
    base_url: str | None = None,
    default_headers: Mapping[str, str] | None = None,
) -> dict[str, Any]:
    """Build the connection keyword arguments for the service client.

    Installs the auth compatibility shim and the client-config fallback first,
    then resolves the API key and base URL from the arguments, the headers and
    the environment. The result contains only the keys that have a value:
    ``base_url`` and ``api_key`` when resolved, and ``default_headers`` (a
    copy of the caller's mapping) when it is non-empty.
    """
    install_logits_auth_compat()
    install_client_config_fallback()

    # Copy so the caller's mapping is never shared with the client.
    headers = dict(default_headers or {})
    resolved_api_key = resolve_api_key(api_key, default_headers=headers)
    resolved_base_url = resolve_base_url(base_url)

    resolved_options = (
        ("base_url", resolved_base_url),
        ("api_key", resolved_api_key),
    )
    kwargs: dict[str, Any] = {
        option: value for option, value in resolved_options if value is not None
    }
    if headers:
        kwargs["default_headers"] = headers
    return kwargs
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# URI scheme prefixes recognised for model/checkpoint paths.
LOGITS_URI_PREFIX = "logits://"
TINKER_URI_PREFIX = "tinker://"

# Tuple form so it can be passed directly to ``str.startswith``.
_VALID_MODEL_PATH_PREFIXES = (LOGITS_URI_PREFIX, TINKER_URI_PREFIX)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def is_valid_model_path(path: str | None) -> bool:
    """Report whether *path* begins with a supported weights URI prefix.

    Both ``tinker://`` and ``logits://`` are accepted, since backends may
    return either form. ``None`` is never valid.
    """
    return path is not None and path.startswith(_VALID_MODEL_PATH_PREFIXES)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def normalize_tinker_path(path: str) -> str:
    """Rewrite a leading ``logits://`` to ``tinker://``; return other paths as-is.

    Intended for local parsing only: the remainder of the URI is untouched.
    """
    if not path.startswith(LOGITS_URI_PREFIX):
        return path
    remainder = path[len(LOGITS_URI_PREFIX) :]
    return f"{TINKER_URI_PREFIX}{remainder}"
|
|
27
|
+
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from concurrent.futures import Future as ConcurrentFuture
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from tinker import types
|
|
7
|
+
from tinker.lib.public_interfaces.rest_client import RestClient as TinkerRestClient
|
|
8
|
+
|
|
9
|
+
from ._path_compat import normalize_tinker_path
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RestClient(TinkerRestClient):
    """RestClient whose ``*_tinker_path`` helpers also accept ``logits://`` URIs.

    Each helper normalizes the URI with ``normalize_tinker_path`` before
    handing it to ``ParsedCheckpointTinkerPath.from_tinker_path``, then calls
    the corresponding inherited method with the parsed training-run and
    checkpoint ids. The normalization is used for local parsing only.
    """

    @staticmethod
    def _logits_parse_path(tinker_path: str) -> types.ParsedCheckpointTinkerPath:
        """Parse a ``tinker://`` or ``logits://`` checkpoint URI."""
        normalized = normalize_tinker_path(tinker_path)
        return types.ParsedCheckpointTinkerPath.from_tinker_path(normalized)

    def get_training_run_by_tinker_path(
        self, tinker_path: str, access_scope: Literal["owned", "accessible"] = "owned"
    ) -> ConcurrentFuture[types.TrainingRun]:
        """Look up the training run that the checkpoint URI belongs to."""
        run_id = self._logits_parse_path(tinker_path).training_run_id
        return self.get_training_run(run_id, access_scope=access_scope)

    async def get_training_run_by_tinker_path_async(
        self, tinker_path: str, access_scope: Literal["owned", "accessible"] = "owned"
    ) -> types.TrainingRun:
        """Async variant of ``get_training_run_by_tinker_path``."""
        run_id = self._logits_parse_path(tinker_path).training_run_id
        return await self.get_training_run_async(run_id, access_scope=access_scope)

    def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
        """Submit deletion of the checkpoint named by the URI."""
        target = self._logits_parse_path(tinker_path)
        submission = self._delete_checkpoint_submit(target.training_run_id, target.checkpoint_id)
        return submission.future()

    async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
        """Async variant of ``delete_checkpoint_from_tinker_path``."""
        target = self._logits_parse_path(tinker_path)
        await self._delete_checkpoint_submit(target.training_run_id, target.checkpoint_id)

    def get_checkpoint_archive_url_from_tinker_path(
        self, tinker_path: str
    ) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
        """Request the archive URL of the checkpoint named by the URI."""
        target = self._logits_parse_path(tinker_path)
        submission = self._get_checkpoint_archive_url_submit(
            target.training_run_id, target.checkpoint_id
        )
        return submission.future()

    async def get_checkpoint_archive_url_from_tinker_path_async(
        self, tinker_path: str
    ) -> types.CheckpointArchiveUrlResponse:
        """Async variant of ``get_checkpoint_archive_url_from_tinker_path``."""
        target = self._logits_parse_path(tinker_path)
        return await self._get_checkpoint_archive_url_submit(
            target.training_run_id, target.checkpoint_id
        )

    def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
        """Submit publication of the checkpoint named by the URI."""
        target = self._logits_parse_path(tinker_path)
        submission = self._publish_checkpoint_submit(target.training_run_id, target.checkpoint_id)
        return submission.future()

    async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
        """Async variant of ``publish_checkpoint_from_tinker_path``."""
        target = self._logits_parse_path(tinker_path)
        await self._publish_checkpoint_submit(target.training_run_id, target.checkpoint_id)

    def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
        """Submit un-publication of the checkpoint named by the URI."""
        target = self._logits_parse_path(tinker_path)
        submission = self._unpublish_checkpoint_submit(target.training_run_id, target.checkpoint_id)
        return submission.future()

    async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
        """Async variant of ``unpublish_checkpoint_from_tinker_path``."""
        target = self._logits_parse_path(tinker_path)
        await self._unpublish_checkpoint_submit(target.training_run_id, target.checkpoint_id)

    def set_checkpoint_ttl_from_tinker_path(
        self, tinker_path: str, ttl_seconds: int | None
    ) -> ConcurrentFuture[None]:
        """Submit a TTL change for the checkpoint named by the URI."""
        target = self._logits_parse_path(tinker_path)
        submission = self._set_checkpoint_ttl_submit(
            target.training_run_id, target.checkpoint_id, ttl_seconds
        )
        return submission.future()

    async def set_checkpoint_ttl_from_tinker_path_async(
        self, tinker_path: str, ttl_seconds: int | None
    ) -> None:
        """Async variant of ``set_checkpoint_ttl_from_tinker_path``."""
        target = self._logits_parse_path(tinker_path)
        await self._set_checkpoint_ttl_submit(
            target.training_run_id, target.checkpoint_id, ttl_seconds
        )
|
|
83
|
+
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import tinker
|
|
6
|
+
from tinker import SamplingClient, types
|
|
7
|
+
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
|
8
|
+
from tinker.lib.retry_handler import RetryConfig
|
|
9
|
+
|
|
10
|
+
from ._config import resolve_service_client_kwargs
|
|
11
|
+
from ._path_compat import is_valid_model_path
|
|
12
|
+
from ._training_client import TrainingClient as LogitsTrainingClient
|
|
13
|
+
from ._rest_client import RestClient as LogitsRestClient
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ServiceClient(tinker.ServiceClient):
    """Logits-first facade over ``tinker.ServiceClient``.

    Compared with the parent class, as visible here:

    * connection options are resolved through ``resolve_service_client_kwargs``;
    * sampling clients accept any ``model_path`` approved by
      ``is_valid_model_path`` and pass it to the backend unchanged;
    * training-client factories return the Logits ``TrainingClient`` wrapper;
    * ``create_rest_client`` returns the Logits ``RestClient``.
    """

    def __init__(
        self,
        user_metadata: dict[str, str] | None = None,
        project_id: str | None = None,
        **kwargs: Any,
    ) -> None:
        # Take the connection options out of ``kwargs`` so they reach the
        # parent exactly once, in resolved form.
        connection_options = {
            option: kwargs.pop(option, None)
            for option in ("base_url", "api_key", "default_headers")
        }
        resolved_options = resolve_service_client_kwargs(**connection_options)
        super().__init__(
            user_metadata=user_metadata,
            project_id=project_id,
            **resolved_options,
            **kwargs,
        )

    @staticmethod
    def _wrap_training_client(training_client: tinker.TrainingClient) -> LogitsTrainingClient:
        """Rebuild *training_client* as a Logits ``TrainingClient``."""
        holder = training_client.holder
        # The id lives in a private field of the upstream client; it is only
        # read here to construct an equivalent wrapper instance.
        client_id = getattr(training_client, "_training_client_id")
        return LogitsTrainingClient(holder, client_id, training_client.model_id)

    @staticmethod
    def _logits_require_sampling_target(model_path: str | None, base_model: str | None) -> None:
        """Raise ``ValueError`` unless the sampling arguments are usable."""
        if model_path is None and base_model is None:
            raise ValueError("Either model_path or base_model must be provided")
        if model_path is not None and not is_valid_model_path(model_path):
            raise ValueError("model_path must start with 'tinker://' or 'logits://'")

    async def _logits_start_sampling_session(
        self,
        model_path: str | None,
        base_model: str | None,
        retry_config: RetryConfig | None,
    ) -> SamplingClient:
        """Create a sampling session on the backend and wrap it in a client."""
        assert self.holder._sampling_client_counter is not None
        seq_id = self.holder._sampling_client_counter
        self.holder._sampling_client_counter += 1

        with self.holder.aclient(ClientConnectionPoolType.SESSION) as client:
            session_request = types.CreateSamplingSessionRequest(
                session_id=self.holder._session_id,
                sampling_session_seq_id=seq_id,
                model_path=model_path,
                base_model=base_model,
            )
            session = await client.service.create_sampling_session(request=session_request)

        return SamplingClient(
            self.holder,
            sampling_session_id=session.sampling_session_id,
            retry_config=retry_config,
        )

    def create_sampling_client(
        self,
        model_path: str | None = None,
        base_model: str | None = None,
        retry_config: RetryConfig | None = None,
    ) -> SamplingClient:
        """Create a ``SamplingClient``, blocking until the session exists.

        Raises:
            ValueError: if neither argument is given, or *model_path* does not
                start with a supported prefix.
        """
        self._logits_require_sampling_target(model_path, base_model)
        pending = self.holder.run_coroutine_threadsafe(
            self._logits_start_sampling_session(model_path, base_model, retry_config)
        )
        return pending.result()

    async def create_sampling_client_async(
        self,
        model_path: str | None = None,
        base_model: str | None = None,
        retry_config: RetryConfig | None = None,
    ) -> SamplingClient:
        """Async variant of ``create_sampling_client``."""
        self._logits_require_sampling_target(model_path, base_model)
        return await self.holder.run_coroutine_threadsafe(
            self._logits_start_sampling_session(model_path, base_model, retry_config)
        )

    def create_lora_training_client(
        self,
        base_model: str,
        rank: int = 32,
        seed: int | None = None,
        train_mlp: bool = True,
        train_attn: bool = True,
        train_unembed: bool = True,
        user_metadata: dict[str, str] | None = None,
    ) -> LogitsTrainingClient:
        """Create a LoRA training client via the parent and wrap the result."""
        upstream_client = super().create_lora_training_client(
            base_model=base_model,
            rank=rank,
            seed=seed,
            train_mlp=train_mlp,
            train_attn=train_attn,
            train_unembed=train_unembed,
            user_metadata=user_metadata,
        )
        return self._wrap_training_client(upstream_client)

    async def create_lora_training_client_async(
        self,
        base_model: str,
        rank: int = 32,
        seed: int | None = None,
        train_mlp: bool = True,
        train_attn: bool = True,
        train_unembed: bool = True,
        user_metadata: dict[str, str] | None = None,
    ) -> LogitsTrainingClient:
        """Async variant of ``create_lora_training_client``."""
        upstream_client = await super().create_lora_training_client_async(
            base_model=base_model,
            rank=rank,
            seed=seed,
            train_mlp=train_mlp,
            train_attn=train_attn,
            train_unembed=train_unembed,
            user_metadata=user_metadata,
        )
        return self._wrap_training_client(upstream_client)

    def create_training_client_from_state(
        self,
        path: str,
        user_metadata: dict[str, str] | None = None,
        weights_access_token: str | None = None,
    ) -> LogitsTrainingClient:
        """Create a training client from saved state and wrap the result."""
        upstream_client = super().create_training_client_from_state(
            path,
            user_metadata=user_metadata,
            weights_access_token=weights_access_token,
        )
        return self._wrap_training_client(upstream_client)

    async def create_training_client_from_state_async(
        self,
        path: str,
        user_metadata: dict[str, str] | None = None,
        weights_access_token: str | None = None,
    ) -> LogitsTrainingClient:
        """Async variant of ``create_training_client_from_state``."""
        upstream_client = await super().create_training_client_from_state_async(
            path,
            user_metadata=user_metadata,
            weights_access_token=weights_access_token,
        )
        return self._wrap_training_client(upstream_client)

    def create_training_client_from_state_with_optimizer(
        self,
        path: str,
        user_metadata: dict[str, str] | None = None,
        weights_access_token: str | None = None,
    ) -> LogitsTrainingClient:
        """Create a training client from state plus optimizer and wrap it."""
        upstream_client = super().create_training_client_from_state_with_optimizer(
            path,
            user_metadata=user_metadata,
            weights_access_token=weights_access_token,
        )
        return self._wrap_training_client(upstream_client)

    async def create_training_client_from_state_with_optimizer_async(
        self,
        path: str,
        user_metadata: dict[str, str] | None = None,
        weights_access_token: str | None = None,
    ) -> LogitsTrainingClient:
        """Async variant of ``create_training_client_from_state_with_optimizer``."""
        upstream_client = await super().create_training_client_from_state_with_optimizer_async(
            path,
            user_metadata=user_metadata,
            weights_access_token=weights_access_token,
        )
        return self._wrap_training_client(upstream_client)

    def create_rest_client(self) -> LogitsRestClient:
        """Return a Logits ``RestClient`` sharing this client's holder."""
        return LogitsRestClient(self.holder)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def create_service_client(**kwargs: Any) -> ServiceClient:
    """Construct a ``ServiceClient``, forwarding every keyword argument."""
    service_client = ServiceClient(**kwargs)
    return service_client
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import tinker
|
|
4
|
+
from tinker import SamplingClient, types
|
|
5
|
+
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
|
6
|
+
from tinker.lib.retry_handler import RetryConfig
|
|
7
|
+
|
|
8
|
+
from ._path_compat import is_valid_model_path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TrainingClient(tinker.TrainingClient):
    """TrainingClient whose sampling clients accept ``logits://`` paths.

    ``model_path`` is validated with ``is_valid_model_path`` (``tinker://`` or
    ``logits://``) and then sent to the backend unchanged when the sampling
    session is created.
    """

    @staticmethod
    def _logits_require_model_path(model_path: str) -> None:
        """Raise ``ValueError`` unless *model_path* has a supported prefix."""
        if not is_valid_model_path(model_path):
            raise ValueError("model_path must start with 'tinker://' or 'logits://'")

    async def _logits_start_sampling_session(
        self, model_path: str, retry_config: RetryConfig | None
    ) -> SamplingClient:
        """Create a sampling session for *model_path* and wrap it in a client."""
        assert self.holder._sampling_client_counter is not None
        seq_id = self.holder._sampling_client_counter
        self.holder._sampling_client_counter += 1

        with self.holder.aclient(ClientConnectionPoolType.SESSION) as client:
            session_request = types.CreateSamplingSessionRequest(
                session_id=self.holder._session_id,
                sampling_session_seq_id=seq_id,
                model_path=model_path,
                base_model=None,
            )
            session = await client.service.create_sampling_session(request=session_request)

        return SamplingClient(
            self.holder,
            sampling_session_id=session.sampling_session_id,
            retry_config=retry_config,
        )

    def create_sampling_client(
        self, model_path: str, retry_config: RetryConfig | None = None
    ) -> SamplingClient:
        """Create a ``SamplingClient`` for *model_path*, blocking until ready.

        Raises:
            ValueError: if *model_path* does not start with a supported prefix.
        """
        self._logits_require_model_path(model_path)
        pending = self.holder.run_coroutine_threadsafe(
            self._logits_start_sampling_session(model_path, retry_config)
        )
        return pending.result()

    async def create_sampling_client_async(
        self, model_path: str, retry_config: RetryConfig | None = None
    ) -> SamplingClient:
        """Async variant of ``create_sampling_client``."""
        self._logits_require_model_path(model_path)
        return await self.holder.run_coroutine_threadsafe(
            self._logits_start_sampling_session(model_path, retry_config)
        )
|
|
75
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import pkgutil
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
import tinker.types as _tinker_types
|
|
8
|
+
from tinker.types import * # noqa: F403
|
|
9
|
+
|
|
10
|
+
# Register every ``tinker.types.*`` submodule under the matching
# ``logits.types.*`` name, so importing the logits name yields the very same
# module object as the tinker one.
for module_info in pkgutil.walk_packages(_tinker_types.__path__, prefix="tinker.types."):
    tinker_module_name = module_info.name
    logits_module_name = tinker_module_name.replace("tinker.", "logits.", 1)
    sys.modules[logits_module_name] = importlib.import_module(tinker_module_name)

# Public names are whatever the star-import brought in, minus this module's
# own bookkeeping. ``annotations`` must be excluded as well: the
# ``from __future__ import annotations`` line binds that name in this
# namespace, and leaving it in ``__all__`` would both re-export it and let the
# loop below overwrite ``__module__`` on the shared ``__future__`` object.
__all__ = [
    name
    for name in globals()
    if not name.startswith("_")
    and name
    not in {
        "annotations",
        "importlib",
        "pkgutil",
        "sys",
        "module_info",
        "tinker_module_name",
        "logits_module_name",
    }
]

# Present the re-exported objects as belonging to ``logits.types``.
__locals = locals()
for __name in __all__:
    try:
        __locals[__name].__module__ = "logits.types"
    except (TypeError, AttributeError):
        # Best effort: some objects do not allow ``__module__`` to be set.
        pass
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from __future__ import annotations

import os

import httpx
import pytest
from respx import MockRouter, Route

import logits
|
|
9
|
+
|
|
10
|
+
# Base URL used for the client and the respx router in these tests; can be
# overridden through the TEST_API_BASE_URL environment variable.
BASE_URL = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def mock_service_client_bootstrap(respx_mock: MockRouter) -> Route:
    """Mock the two requests made while a ServiceClient starts up.

    Registers ``POST /api/v1/client/config`` (200, empty JSON object) and
    ``POST /api/v1/create_session`` (200, a fixed session id).

    Returns:
        The create-session route, so callers can inspect ``.called`` and
        ``.calls``. The return annotation is ``Route`` because that is what
        ``respx_mock.post(...).mock(...)`` yields, not an ``httpx.Response``.
    """
    respx_mock.post("/api/v1/client/config").mock(
        return_value=httpx.Response(200, json={})
    )
    create_session_route = respx_mock.post("/api/v1/create_session").mock(
        return_value=httpx.Response(200, json={"session_id": "test-session-id"})
    )
    return create_session_route
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.mark.respx(base_url=BASE_URL)
def test_service_client_uses_non_tml_logits_key_as_header(respx_mock: MockRouter) -> None:
    """An explicit API key is sent in the ``X-API-Key`` request header."""
    session_route = mock_service_client_bootstrap(respx_mock)

    client = logits.ServiceClient(base_url=BASE_URL, api_key="logits-test-key")
    client.holder.close()

    assert session_route.called
    first_request = session_route.calls[0].request
    assert first_request.headers["X-API-Key"] == "logits-test-key"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.mark.respx(base_url=BASE_URL)
def test_service_client_prefers_logits_env_vars(
    respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
) -> None:
    """LOGITS_* environment variables win over their TINKER_* counterparts."""
    session_route = mock_service_client_bootstrap(respx_mock)
    environment = (
        ("LOGITS_API_KEY", "logits-env-key"),
        ("LOGITS_BASE_URL", BASE_URL),
        ("TINKER_API_KEY", "tml-tinker-fallback"),
        ("TINKER_BASE_URL", "https://example.invalid"),
    )
    for variable, value in environment:
        monkeypatch.setenv(variable, value)

    client = logits.ServiceClient()
    client.holder.close()

    assert session_route.called
    sent = session_route.calls[0].request
    assert str(sent.url).startswith(f"{BASE_URL}/api/v1/create_session")
    assert sent.headers["X-API-Key"] == "logits-env-key"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.mark.respx(base_url=BASE_URL)
def test_create_service_client_returns_logits_service_client(respx_mock: MockRouter) -> None:
    """The factory function returns a ``logits.ServiceClient`` instance."""
    mock_service_client_bootstrap(respx_mock)

    client = logits.create_service_client(base_url=BASE_URL, api_key="tml-direct-key")
    client.holder.close()

    assert isinstance(client, logits.ServiceClient)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@pytest.mark.respx(base_url=BASE_URL)
def test_service_client_falls_back_when_client_config_missing(respx_mock: MockRouter) -> None:
    """A 404 from /api/v1/client/config must not prevent session creation."""
    missing_config = httpx.Response(404, text="404 page not found")
    config_route = respx_mock.post("/api/v1/client/config").mock(return_value=missing_config)

    session_created = httpx.Response(200, json={"session_id": "fallback-session"})
    session_route = respx_mock.post("/api/v1/create_session").mock(return_value=session_created)

    client = logits.ServiceClient(base_url=BASE_URL, api_key="logits-test-key")
    try:
        assert config_route.called
        assert session_route.called
        assert client.holder._session_id == "fallback-session"
    finally:
        # Close the holder even when an assertion above fails.
        client.holder.close()
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from logits.types import LossFnType, TensorData
|
|
4
|
+
from logits.types.tensor_data import TensorData as TensorDataModule
|
|
5
|
+
from tinker.types.tensor_data import TensorData as TinkerTensorData
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_types_package_reexports_tinker_types() -> None:
    """``logits.types`` exposes the tinker objects themselves, relabelled."""
    # Both the package-level name and the submodule alias must be the very
    # same class object as tinker's.
    for reexported in (TensorData, TensorDataModule):
        assert reexported is TinkerTensorData
    assert LossFnType.__module__ == "logits.types"
|