opengradient 0.3.24__py3-none-any.whl → 0.3.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +125 -98
- opengradient/account.py +6 -4
- opengradient/cli.py +151 -154
- opengradient/client.py +300 -362
- opengradient/defaults.py +7 -7
- opengradient/exceptions.py +25 -0
- opengradient/llm/__init__.py +7 -10
- opengradient/llm/og_langchain.py +34 -51
- opengradient/llm/og_openai.py +54 -61
- opengradient/mltools/__init__.py +2 -7
- opengradient/mltools/model_tool.py +20 -26
- opengradient/proto/infer_pb2.py +24 -29
- opengradient/proto/infer_pb2_grpc.py +95 -86
- opengradient/types.py +39 -35
- opengradient/utils.py +30 -31
- {opengradient-0.3.24.dist-info → opengradient-0.3.25.dist-info}/METADATA +1 -1
- opengradient-0.3.25.dist-info/RECORD +26 -0
- opengradient-0.3.24.dist-info/RECORD +0 -26
- {opengradient-0.3.24.dist-info → opengradient-0.3.25.dist-info}/LICENSE +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.25.dist-info}/WHEEL +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.25.dist-info}/entry_points.txt +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.25.dist-info}/top_level.txt +0 -0
opengradient/cli.py
CHANGED
|
@@ -3,14 +3,10 @@ import json
|
|
|
3
3
|
import logging
|
|
4
4
|
import webbrowser
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
7
|
-
from enum import Enum
|
|
8
|
-
from . import types
|
|
6
|
+
from typing import Dict, List, Optional
|
|
9
7
|
|
|
10
8
|
import click
|
|
11
9
|
|
|
12
|
-
import opengradient
|
|
13
|
-
|
|
14
10
|
from .account import EthAccount, generate_eth_account
|
|
15
11
|
from .client import Client
|
|
16
12
|
from .defaults import (
|
|
@@ -20,21 +16,23 @@ from .defaults import (
|
|
|
20
16
|
DEFAULT_OG_FAUCET_URL,
|
|
21
17
|
DEFAULT_RPC_URL,
|
|
22
18
|
)
|
|
23
|
-
from .types import InferenceMode, LlmInferenceMode
|
|
19
|
+
from .types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM
|
|
24
20
|
|
|
25
|
-
OG_CONFIG_FILE = Path.home() /
|
|
21
|
+
OG_CONFIG_FILE = Path.home() / ".opengradient_config.json"
|
|
26
22
|
|
|
27
23
|
|
|
28
24
|
def load_og_config():
|
|
29
25
|
if OG_CONFIG_FILE.exists():
|
|
30
|
-
with OG_CONFIG_FILE.open(
|
|
26
|
+
with OG_CONFIG_FILE.open("r") as f:
|
|
31
27
|
return json.load(f)
|
|
32
28
|
return {}
|
|
33
29
|
|
|
30
|
+
|
|
34
31
|
def save_og_config(ctx):
|
|
35
|
-
with OG_CONFIG_FILE.open(
|
|
32
|
+
with OG_CONFIG_FILE.open("w") as f:
|
|
36
33
|
json.dump(ctx.obj, f)
|
|
37
34
|
|
|
35
|
+
|
|
38
36
|
# Convert string to dictionary click parameter typing
|
|
39
37
|
class DictParamType(click.ParamType):
|
|
40
38
|
name = "dictionary"
|
|
@@ -56,6 +54,7 @@ class DictParamType(click.ParamType):
|
|
|
56
54
|
except (ValueError, SyntaxError):
|
|
57
55
|
self.fail(f"'{value}' is not a valid dictionary", param, ctx)
|
|
58
56
|
|
|
57
|
+
|
|
59
58
|
Dict = DictParamType()
|
|
60
59
|
|
|
61
60
|
# Supported inference modes
|
|
@@ -78,6 +77,7 @@ LlmModels = {
|
|
|
78
77
|
"meta-llama/Llama-3.1-70B-Instruct",
|
|
79
78
|
}
|
|
80
79
|
|
|
80
|
+
|
|
81
81
|
def initialize_config(ctx):
|
|
82
82
|
"""Interactively initialize OpenGradient config"""
|
|
83
83
|
if ctx.obj: # Check if config data already exists
|
|
@@ -85,7 +85,7 @@ def initialize_config(ctx):
|
|
|
85
85
|
click.echo("You can view your current config with 'opengradient config show'.")
|
|
86
86
|
|
|
87
87
|
click.echo("Initializing OpenGradient config...")
|
|
88
|
-
click.secho(f"Config will be stored in: {OG_CONFIG_FILE}", fg=
|
|
88
|
+
click.secho(f"Config will be stored in: {OG_CONFIG_FILE}", fg="cyan")
|
|
89
89
|
|
|
90
90
|
# Check if user has an existing account
|
|
91
91
|
has_account = click.confirm("Do you already have an OpenGradient account?", default=True)
|
|
@@ -95,52 +95,56 @@ def initialize_config(ctx):
|
|
|
95
95
|
if eth_account is None:
|
|
96
96
|
click.echo("Account creation cancelled. Config initialization aborted.")
|
|
97
97
|
return
|
|
98
|
-
ctx.obj[
|
|
98
|
+
ctx.obj["private_key"] = eth_account.private_key
|
|
99
99
|
else:
|
|
100
|
-
ctx.obj[
|
|
100
|
+
ctx.obj["private_key"] = click.prompt("Enter your OpenGradient private key", type=str)
|
|
101
101
|
|
|
102
102
|
# Make email and password optional
|
|
103
|
-
email = click.prompt(
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
ctx.obj[
|
|
111
|
-
|
|
112
|
-
|
|
103
|
+
email = click.prompt(
|
|
104
|
+
"Enter your OpenGradient Hub email address (optional, press Enter to skip)", type=str, default="", show_default=False
|
|
105
|
+
)
|
|
106
|
+
ctx.obj["email"] = email if email else None
|
|
107
|
+
password = click.prompt(
|
|
108
|
+
"Enter your OpenGradient Hub password (optional, press Enter to skip)", type=str, hide_input=True, default="", show_default=False
|
|
109
|
+
)
|
|
110
|
+
ctx.obj["password"] = password if password else None
|
|
111
|
+
|
|
112
|
+
ctx.obj["rpc_url"] = DEFAULT_RPC_URL
|
|
113
|
+
ctx.obj["contract_address"] = DEFAULT_INFERENCE_CONTRACT_ADDRESS
|
|
114
|
+
|
|
113
115
|
save_og_config(ctx)
|
|
114
116
|
click.echo("Config has been saved.")
|
|
115
|
-
click.secho("You can run 'opengradient config show' to see configs.", fg=
|
|
117
|
+
click.secho("You can run 'opengradient config show' to see configs.", fg="green")
|
|
116
118
|
|
|
117
119
|
|
|
118
120
|
@click.group()
|
|
119
121
|
@click.pass_context
|
|
120
122
|
def cli(ctx):
|
|
121
123
|
"""
|
|
122
|
-
CLI for OpenGradient SDK.
|
|
124
|
+
CLI for OpenGradient SDK.
|
|
123
125
|
|
|
124
126
|
Run 'opengradient config show' to make sure you have configs set up.
|
|
125
|
-
|
|
127
|
+
|
|
126
128
|
Visit https://docs.opengradient.ai/developers/python_sdk/ for more documentation.
|
|
127
129
|
"""
|
|
128
|
-
# Load existing config
|
|
130
|
+
# Load existing config
|
|
129
131
|
ctx.obj = load_og_config()
|
|
130
132
|
|
|
131
|
-
no_client_commands = [
|
|
133
|
+
no_client_commands = ["config", "create-account", "version"]
|
|
132
134
|
|
|
133
135
|
# Only create client if this is not a config management command
|
|
134
136
|
if ctx.invoked_subcommand in no_client_commands:
|
|
135
137
|
return
|
|
136
138
|
|
|
137
|
-
if all(key in ctx.obj for key in [
|
|
139
|
+
if all(key in ctx.obj for key in ["private_key", "rpc_url", "contract_address"]):
|
|
138
140
|
try:
|
|
139
|
-
ctx.obj[
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
141
|
+
ctx.obj["client"] = Client(
|
|
142
|
+
private_key=ctx.obj["private_key"],
|
|
143
|
+
rpc_url=ctx.obj["rpc_url"],
|
|
144
|
+
contract_address=DEFAULT_INFERENCE_CONTRACT_ADDRESS,
|
|
145
|
+
email=ctx.obj.get("email"),
|
|
146
|
+
password=ctx.obj.get("password"),
|
|
147
|
+
)
|
|
144
148
|
except Exception as e:
|
|
145
149
|
click.echo(f"Failed to create OpenGradient client: {str(e)}")
|
|
146
150
|
ctx.exit(1)
|
|
@@ -167,7 +171,7 @@ def init(ctx):
|
|
|
167
171
|
@click.pass_context
|
|
168
172
|
def show(ctx):
|
|
169
173
|
"""Display current config information"""
|
|
170
|
-
click.secho(f"Config file location: {OG_CONFIG_FILE}", fg=
|
|
174
|
+
click.secho(f"Config file location: {OG_CONFIG_FILE}", fg="cyan")
|
|
171
175
|
|
|
172
176
|
if not ctx.obj:
|
|
173
177
|
click.echo("Config is empty. Run 'opengradient config init' to initialize it.")
|
|
@@ -175,13 +179,13 @@ def show(ctx):
|
|
|
175
179
|
|
|
176
180
|
click.echo("Current config:")
|
|
177
181
|
for key, value in ctx.obj.items():
|
|
178
|
-
if key !=
|
|
179
|
-
if (key ==
|
|
182
|
+
if key != "client": # Don't display the client object
|
|
183
|
+
if (key == "password" or key == "private_key") and value is not None:
|
|
180
184
|
click.echo(f"{key}: {'*' * len(value)}") # Mask the password
|
|
181
185
|
elif value is None:
|
|
182
186
|
click.echo(f"{key}: Not set")
|
|
183
187
|
else:
|
|
184
|
-
click.echo(f"{key}: {value}")
|
|
188
|
+
click.echo(f"{key}: {value}")
|
|
185
189
|
|
|
186
190
|
|
|
187
191
|
@config.command()
|
|
@@ -201,8 +205,8 @@ def clear(ctx):
|
|
|
201
205
|
|
|
202
206
|
|
|
203
207
|
@cli.command()
|
|
204
|
-
@click.option(
|
|
205
|
-
@click.option(
|
|
208
|
+
@click.option("--repo", "-r", "--name", "repo_name", required=True, help="Name of the new model repository")
|
|
209
|
+
@click.option("--description", "-d", required=True, help="Description of the model")
|
|
206
210
|
@click.pass_obj
|
|
207
211
|
def create_model_repo(obj, repo_name: str, description: str):
|
|
208
212
|
"""
|
|
@@ -217,7 +221,7 @@ def create_model_repo(obj, repo_name: str, description: str):
|
|
|
217
221
|
opengradient create-model-repo --name "my_new_model" --description "A new model for XYZ task"
|
|
218
222
|
opengradient create-model-repo -n "my_new_model" -d "A new model for XYZ task"
|
|
219
223
|
"""
|
|
220
|
-
client: Client = obj[
|
|
224
|
+
client: Client = obj["client"]
|
|
221
225
|
|
|
222
226
|
try:
|
|
223
227
|
result = client.create_model(repo_name, description)
|
|
@@ -227,14 +231,14 @@ def create_model_repo(obj, repo_name: str, description: str):
|
|
|
227
231
|
|
|
228
232
|
|
|
229
233
|
@cli.command()
|
|
230
|
-
@click.option(
|
|
231
|
-
@click.option(
|
|
232
|
-
@click.option(
|
|
234
|
+
@click.option("--repo", "-r", "repo_name", required=True, help="Name of the existing model repository")
|
|
235
|
+
@click.option("--notes", "-n", help="Version notes (optional)")
|
|
236
|
+
@click.option("--major", "-m", is_flag=True, default=False, help="Flag to indicate a major version update")
|
|
233
237
|
@click.pass_obj
|
|
234
238
|
def create_version(obj, repo_name: str, notes: str, major: bool):
|
|
235
239
|
"""Create a new version in an existing model repository.
|
|
236
240
|
|
|
237
|
-
This command creates a new version for the specified model repository.
|
|
241
|
+
This command creates a new version for the specified model repository.
|
|
238
242
|
You can optionally provide version notes and indicate if it's a major version update.
|
|
239
243
|
|
|
240
244
|
Example usage:
|
|
@@ -243,7 +247,7 @@ def create_version(obj, repo_name: str, notes: str, major: bool):
|
|
|
243
247
|
opengradient create-version --repo my_model_repo --notes "Added new feature X" --major
|
|
244
248
|
opengradient create-version -r my_model_repo -n "Bug fixes"
|
|
245
249
|
"""
|
|
246
|
-
client: Client = obj[
|
|
250
|
+
client: Client = obj["client"]
|
|
247
251
|
|
|
248
252
|
try:
|
|
249
253
|
result = client.create_version(repo_name, notes, major)
|
|
@@ -253,10 +257,11 @@ def create_version(obj, repo_name: str, notes: str, major: bool):
|
|
|
253
257
|
|
|
254
258
|
|
|
255
259
|
@cli.command()
|
|
256
|
-
@click.argument(
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
@click.option(
|
|
260
|
+
@click.argument(
|
|
261
|
+
"file_path", type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, path_type=Path), metavar="FILE_PATH"
|
|
262
|
+
)
|
|
263
|
+
@click.option("--repo", "-r", "repo_name", required=True, help="Name of the model repository")
|
|
264
|
+
@click.option("--version", "-v", required=True, help='Version of the model (e.g., "0.01")')
|
|
260
265
|
@click.pass_obj
|
|
261
266
|
def upload_file(obj, file_path: Path, repo_name: str, version: str):
|
|
262
267
|
"""
|
|
@@ -270,7 +275,7 @@ def upload_file(obj, file_path: Path, repo_name: str, version: str):
|
|
|
270
275
|
opengradient upload-file path/to/model.onnx --repo my_model_repo --version 0.01
|
|
271
276
|
opengradient upload-file path/to/model.onnx -r my_model_repo -v 0.01
|
|
272
277
|
"""
|
|
273
|
-
client: Client = obj[
|
|
278
|
+
client: Client = obj["client"]
|
|
274
279
|
|
|
275
280
|
try:
|
|
276
281
|
result = client.upload(file_path, repo_name, version)
|
|
@@ -280,13 +285,17 @@ def upload_file(obj, file_path: Path, repo_name: str, version: str):
|
|
|
280
285
|
|
|
281
286
|
|
|
282
287
|
@cli.command()
|
|
283
|
-
@click.option(
|
|
284
|
-
@click.option(
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
@click.option(
|
|
288
|
-
|
|
289
|
-
|
|
288
|
+
@click.option("--model", "-m", "model_cid", required=True, help="CID of the model to run inference on")
|
|
289
|
+
@click.option(
|
|
290
|
+
"--mode", "inference_mode", type=click.Choice(InferenceModes.keys()), default="VANILLA", help="Inference mode (default: VANILLA)"
|
|
291
|
+
)
|
|
292
|
+
@click.option("--input", "-d", "input_data", type=Dict, help="Input data for inference as a JSON string")
|
|
293
|
+
@click.option(
|
|
294
|
+
"--input-file",
|
|
295
|
+
"-f",
|
|
296
|
+
type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, path_type=Path),
|
|
297
|
+
help="JSON file containing input data for inference",
|
|
298
|
+
)
|
|
290
299
|
@click.pass_context
|
|
291
300
|
def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path):
|
|
292
301
|
"""
|
|
@@ -301,27 +310,27 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
|
|
|
301
310
|
opengradient infer --model Qm... --mode VANILLA --input '{"key": "value"}'
|
|
302
311
|
opengradient infer -m Qm... -i ZKML -f input_data.json
|
|
303
312
|
"""
|
|
304
|
-
client: Client = ctx.obj[
|
|
313
|
+
client: Client = ctx.obj["client"]
|
|
305
314
|
|
|
306
315
|
try:
|
|
307
316
|
if not input_data and not input_file:
|
|
308
317
|
click.echo("Must specify either input_data or input_file")
|
|
309
318
|
ctx.exit(1)
|
|
310
319
|
return
|
|
311
|
-
|
|
320
|
+
|
|
312
321
|
if input_data and input_file:
|
|
313
322
|
click.echo("Cannot have both input_data and input_file")
|
|
314
323
|
ctx.exit(1)
|
|
315
324
|
return
|
|
316
|
-
|
|
325
|
+
|
|
317
326
|
if input_data:
|
|
318
327
|
model_input = input_data
|
|
319
328
|
|
|
320
329
|
if input_file:
|
|
321
|
-
with input_file.open(
|
|
330
|
+
with input_file.open("r") as file:
|
|
322
331
|
model_input = json.load(file)
|
|
323
|
-
|
|
324
|
-
click.echo(f
|
|
332
|
+
|
|
333
|
+
click.echo(f'Running {inference_mode} inference for model "{model_cid}"')
|
|
325
334
|
tx_hash, model_output = client.infer(model_cid=model_cid, inference_mode=InferenceModes[inference_mode], model_input=model_input)
|
|
326
335
|
|
|
327
336
|
click.echo() # Add a newline for better spacing
|
|
@@ -336,7 +345,7 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
|
|
|
336
345
|
click.echo()
|
|
337
346
|
|
|
338
347
|
click.secho("Inference result:", fg="green")
|
|
339
|
-
formatted_output = json.dumps(model_output, indent=2, default=lambda x: x.tolist() if hasattr(x,
|
|
348
|
+
formatted_output = json.dumps(model_output, indent=2, default=lambda x: x.tolist() if hasattr(x, "tolist") else str(x))
|
|
340
349
|
click.echo(formatted_output)
|
|
341
350
|
except json.JSONDecodeError as e:
|
|
342
351
|
click.echo(f"Error decoding JSON: {e}", err=True)
|
|
@@ -344,16 +353,25 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
|
|
|
344
353
|
except Exception as e:
|
|
345
354
|
click.echo(f"Error running inference: {str(e)}")
|
|
346
355
|
|
|
356
|
+
|
|
347
357
|
@cli.command()
|
|
348
|
-
@click.option(
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
358
|
+
@click.option(
|
|
359
|
+
"--model",
|
|
360
|
+
"-m",
|
|
361
|
+
"model_cid",
|
|
362
|
+
type=click.Choice([e.value for e in LLM]),
|
|
363
|
+
required=True,
|
|
364
|
+
help="CID of the LLM model to run inference on",
|
|
365
|
+
)
|
|
366
|
+
@click.option(
|
|
367
|
+
"--mode", "inference_mode", type=click.Choice(LlmInferenceModes.keys()), default="VANILLA", help="Inference mode (default: VANILLA)"
|
|
368
|
+
)
|
|
369
|
+
@click.option("--prompt", "-p", required=True, help="Input prompt for the LLM completion")
|
|
370
|
+
@click.option("--max-tokens", type=int, default=100, help="Maximum number of tokens for LLM completion output")
|
|
371
|
+
@click.option("--stop-sequence", multiple=True, help="Stop sequences for LLM")
|
|
372
|
+
@click.option("--temperature", type=float, default=0.0, help="Temperature for LLM inference (0.0 to 1.0)")
|
|
355
373
|
@click.pass_context
|
|
356
|
-
def completion(ctx, model_cid: str, inference_mode: str,
|
|
374
|
+
def completion(ctx, model_cid: str, inference_mode: str, prompt: str, max_tokens: int, stop_sequence: List[str], temperature: float):
|
|
357
375
|
"""
|
|
358
376
|
Run completion inference on an LLM model.
|
|
359
377
|
|
|
@@ -365,22 +383,23 @@ def completion(ctx, model_cid: str, inference_mode: str, prompt: str, max_token
|
|
|
365
383
|
opengradient completion --model meta-llama/Meta-Llama-3-8B-Instruct --prompt "Hello, how are you?" --max-tokens 50 --temperature 0.7
|
|
366
384
|
opengradient completion -m meta-llama/Meta-Llama-3-8B-Instruct -p "Translate to French: Hello world" --stop-sequence "." --stop-sequence "\\n"
|
|
367
385
|
"""
|
|
368
|
-
client: Client = ctx.obj[
|
|
386
|
+
client: Client = ctx.obj["client"]
|
|
369
387
|
try:
|
|
370
|
-
click.echo(f
|
|
388
|
+
click.echo(f'Running LLM completion inference for model "{model_cid}"\n')
|
|
371
389
|
tx_hash, llm_output = client.llm_completion(
|
|
372
390
|
model_cid=model_cid,
|
|
373
391
|
inference_mode=LlmInferenceModes[inference_mode],
|
|
374
392
|
prompt=prompt,
|
|
375
393
|
max_tokens=max_tokens,
|
|
376
394
|
stop_sequence=list(stop_sequence),
|
|
377
|
-
temperature=temperature
|
|
395
|
+
temperature=temperature,
|
|
378
396
|
)
|
|
379
397
|
|
|
380
398
|
print_llm_completion_result(model_cid, tx_hash, llm_output)
|
|
381
399
|
except Exception as e:
|
|
382
400
|
click.echo(f"Error running LLM completion: {str(e)}")
|
|
383
401
|
|
|
402
|
+
|
|
384
403
|
def print_llm_completion_result(model_cid, tx_hash, llm_output):
|
|
385
404
|
click.secho("✅ LLM completion Successful", fg="green", bold=True)
|
|
386
405
|
click.echo("──────────────────────────────────────")
|
|
@@ -399,46 +418,32 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output):
|
|
|
399
418
|
|
|
400
419
|
|
|
401
420
|
@cli.command()
|
|
402
|
-
@click.option(
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
@click.option(
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
@click.option(
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
default=0.0,
|
|
429
|
-
help='Temperature for LLM inference (0.0 to 1.0)')
|
|
430
|
-
@click.option('--tools',
|
|
431
|
-
type=str,
|
|
432
|
-
default=None,
|
|
433
|
-
help='Tool configurations in JSON format')
|
|
434
|
-
@click.option('--tools-file',
|
|
435
|
-
type=click.Path(exists=True, path_type=Path),
|
|
436
|
-
required=False,
|
|
437
|
-
help='Path to JSON file containing tool configurations')
|
|
438
|
-
@click.option('--tool-choice',
|
|
439
|
-
type=str,
|
|
440
|
-
default='',
|
|
441
|
-
help='Specific tool choice for the LLM')
|
|
421
|
+
@click.option(
|
|
422
|
+
"--model",
|
|
423
|
+
"-m",
|
|
424
|
+
"model_cid",
|
|
425
|
+
type=click.Choice([e.value for e in LLM]),
|
|
426
|
+
required=True,
|
|
427
|
+
help="CID of the LLM model to run inference on",
|
|
428
|
+
)
|
|
429
|
+
@click.option(
|
|
430
|
+
"--mode", "inference_mode", type=click.Choice(LlmInferenceModes.keys()), default="VANILLA", help="Inference mode (default: VANILLA)"
|
|
431
|
+
)
|
|
432
|
+
@click.option("--messages", type=str, required=False, help="Input messages for the chat inference in JSON format")
|
|
433
|
+
@click.option(
|
|
434
|
+
"--messages-file",
|
|
435
|
+
type=click.Path(exists=True, path_type=Path),
|
|
436
|
+
required=False,
|
|
437
|
+
help="Path to JSON file containing input messages for the chat inference",
|
|
438
|
+
)
|
|
439
|
+
@click.option("--max-tokens", type=int, default=100, help="Maximum number of tokens for LLM output")
|
|
440
|
+
@click.option("--stop-sequence", type=str, default=None, multiple=True, help="Stop sequences for LLM")
|
|
441
|
+
@click.option("--temperature", type=float, default=0.0, help="Temperature for LLM inference (0.0 to 1.0)")
|
|
442
|
+
@click.option("--tools", type=str, default=None, help="Tool configurations in JSON format")
|
|
443
|
+
@click.option(
|
|
444
|
+
"--tools-file", type=click.Path(exists=True, path_type=Path), required=False, help="Path to JSON file containing tool configurations"
|
|
445
|
+
)
|
|
446
|
+
@click.option("--tool-choice", type=str, default="", help="Specific tool choice for the LLM")
|
|
442
447
|
@click.pass_context
|
|
443
448
|
def chat(
|
|
444
449
|
ctx,
|
|
@@ -451,7 +456,8 @@ def chat(
|
|
|
451
456
|
temperature: float,
|
|
452
457
|
tools: Optional[str],
|
|
453
458
|
tools_file: Optional[Path],
|
|
454
|
-
tool_choice: Optional[str]
|
|
459
|
+
tool_choice: Optional[str],
|
|
460
|
+
):
|
|
455
461
|
"""
|
|
456
462
|
Run chat inference on an LLM model.
|
|
457
463
|
|
|
@@ -465,9 +471,9 @@ def chat(
|
|
|
465
471
|
opengradient chat --model meta-llama/Meta-Llama-3-8B-Instruct --messages '[{"role":"user","content":"hello"}]' --max-tokens 50 --temperature 0.7
|
|
466
472
|
opengradient chat --model mistralai/Mistral-7B-Instruct-v0.3 --messages-file messages.json --tools-file tools.json --max-tokens 200 --stop-sequence "." --stop-sequence "\\n"
|
|
467
473
|
"""
|
|
468
|
-
client: Client = ctx.obj[
|
|
474
|
+
client: Client = ctx.obj["client"]
|
|
469
475
|
try:
|
|
470
|
-
click.echo(f
|
|
476
|
+
click.echo(f'Running LLM chat inference for model "{model_cid}"\n')
|
|
471
477
|
if not messages and not messages_file:
|
|
472
478
|
click.echo("Must specify either messages or messages-file")
|
|
473
479
|
ctx.exit(1)
|
|
@@ -484,16 +490,16 @@ def chat(
|
|
|
484
490
|
click.echo(f"Failed to parse messages: {e}")
|
|
485
491
|
ctx.exit(1)
|
|
486
492
|
else:
|
|
487
|
-
with messages_file.open(
|
|
493
|
+
with messages_file.open("r") as file:
|
|
488
494
|
messages = json.load(file)
|
|
489
495
|
|
|
490
496
|
# Parse tools if provided
|
|
491
|
-
if (tools and tools !=
|
|
497
|
+
if (tools and tools != "[]") and tools_file:
|
|
492
498
|
click.echo("Cannot have both tools and tools-file")
|
|
493
499
|
click.exit(1)
|
|
494
500
|
return
|
|
495
|
-
|
|
496
|
-
parsed_tools=[]
|
|
501
|
+
|
|
502
|
+
parsed_tools = []
|
|
497
503
|
if tools:
|
|
498
504
|
try:
|
|
499
505
|
parsed_tools = json.loads(tools)
|
|
@@ -508,7 +514,7 @@ def chat(
|
|
|
508
514
|
|
|
509
515
|
if tools_file:
|
|
510
516
|
try:
|
|
511
|
-
with tools_file.open(
|
|
517
|
+
with tools_file.open("r") as file:
|
|
512
518
|
parsed_tools = json.load(file)
|
|
513
519
|
if not isinstance(parsed_tools, list):
|
|
514
520
|
click.echo("Tools must be a JSON array")
|
|
@@ -518,7 +524,7 @@ def chat(
|
|
|
518
524
|
click.echo("Failed to load JSON from tools_file: %s" % e)
|
|
519
525
|
ctx.exit(1)
|
|
520
526
|
return
|
|
521
|
-
|
|
527
|
+
|
|
522
528
|
if not tools and not tools_file:
|
|
523
529
|
parsed_tools = None
|
|
524
530
|
|
|
@@ -537,6 +543,7 @@ def chat(
|
|
|
537
543
|
except Exception as e:
|
|
538
544
|
click.echo(f"Error running LLM chat inference: {str(e)}")
|
|
539
545
|
|
|
546
|
+
|
|
540
547
|
def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output):
|
|
541
548
|
click.secho("✅ LLM Chat Successful", fg="green", bold=True)
|
|
542
549
|
click.echo("──────────────────────────────────────")
|
|
@@ -556,10 +563,11 @@ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output):
|
|
|
556
563
|
click.echo()
|
|
557
564
|
for key, value in chat_output.items():
|
|
558
565
|
# If the value doesn't give any information, don't print it
|
|
559
|
-
if value != None and value != "" and value !=
|
|
566
|
+
if value != None and value != "" and value != "[]" and value != []:
|
|
560
567
|
click.echo(f"{key}: {value}")
|
|
561
568
|
click.echo()
|
|
562
569
|
|
|
570
|
+
|
|
563
571
|
@cli.command()
|
|
564
572
|
def create_account():
|
|
565
573
|
"""Create a new test account for OpenGradient inference and model management"""
|
|
@@ -596,22 +604,16 @@ def create_account_impl() -> EthAccount:
|
|
|
596
604
|
click.echo("Account Creation Complete!".center(50))
|
|
597
605
|
click.echo("=" * 50)
|
|
598
606
|
click.echo("\nYour OpenGradient account has been successfully created and funded.")
|
|
599
|
-
click.secho(f"Address: {eth_account.address}", fg=
|
|
600
|
-
click.secho(f"Private Key: {eth_account.private_key}", fg=
|
|
601
|
-
click.secho("\nPlease save this information for your records.\n", fg=
|
|
607
|
+
click.secho(f"Address: {eth_account.address}", fg="green")
|
|
608
|
+
click.secho(f"Private Key: {eth_account.private_key}", fg="green")
|
|
609
|
+
click.secho("\nPlease save this information for your records.\n", fg="cyan")
|
|
602
610
|
|
|
603
611
|
return eth_account
|
|
604
612
|
|
|
605
613
|
|
|
606
614
|
@cli.command()
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
click.echo(f"OpenGradient CLI version: {opengradient.__version__}")
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
@cli.command()
|
|
613
|
-
@click.option('--repo', '-r', 'repo_name', required=True, help='Name of the model repository')
|
|
614
|
-
@click.option('--version', '-v', required=True, help='Version of the model (e.g., "0.01")')
|
|
615
|
+
@click.option("--repo", "-r", "repo_name", required=True, help="Name of the model repository")
|
|
616
|
+
@click.option("--version", "-v", required=True, help='Version of the model (e.g., "0.01")')
|
|
615
617
|
@click.pass_obj
|
|
616
618
|
def list_files(client: Client, repo_name: str, version: str):
|
|
617
619
|
"""
|
|
@@ -638,33 +640,27 @@ def list_files(client: Client, repo_name: str, version: str):
|
|
|
638
640
|
|
|
639
641
|
|
|
640
642
|
@cli.command()
|
|
641
|
-
@click.option(
|
|
642
|
-
@click.option(
|
|
643
|
-
@click.option(
|
|
644
|
-
|
|
645
|
-
@click.option(
|
|
646
|
-
@click.option('--height', type=int, default=1024, help='Output image height')
|
|
643
|
+
@click.option("--model", "-m", required=True, help="Model identifier for image generation")
|
|
644
|
+
@click.option("--prompt", "-p", required=True, help="Text prompt for generating the image")
|
|
645
|
+
@click.option("--output-path", "-o", required=True, type=click.Path(path_type=Path), help="Output file path for the generated image")
|
|
646
|
+
@click.option("--width", type=int, default=1024, help="Output image width")
|
|
647
|
+
@click.option("--height", type=int, default=1024, help="Output image height")
|
|
647
648
|
@click.pass_context
|
|
648
649
|
def generate_image(ctx, model: str, prompt: str, output_path: Path, width: int, height: int):
|
|
649
650
|
"""
|
|
650
651
|
Generate an image using a diffusion model.
|
|
651
652
|
|
|
652
653
|
Example usage:
|
|
653
|
-
opengradient generate-image --model stabilityai/stable-diffusion-xl-base-1.0
|
|
654
|
+
opengradient generate-image --model stabilityai/stable-diffusion-xl-base-1.0
|
|
654
655
|
--prompt "A beautiful sunset over mountains" --output-path sunset.png
|
|
655
656
|
"""
|
|
656
|
-
client: Client = ctx.obj[
|
|
657
|
+
client: Client = ctx.obj["client"]
|
|
657
658
|
try:
|
|
658
|
-
click.echo(f
|
|
659
|
-
image_data = client.generate_image(
|
|
660
|
-
model_cid=model,
|
|
661
|
-
prompt=prompt,
|
|
662
|
-
width=width,
|
|
663
|
-
height=height
|
|
664
|
-
)
|
|
659
|
+
click.echo(f'Generating image with model "{model}"')
|
|
660
|
+
image_data = client.generate_image(model_cid=model, prompt=prompt, width=width, height=height)
|
|
665
661
|
|
|
666
662
|
# Save the image
|
|
667
|
-
with open(output_path,
|
|
663
|
+
with open(output_path, "wb") as f:
|
|
668
664
|
f.write(image_data)
|
|
669
665
|
|
|
670
666
|
click.echo() # Add a newline for better spacing
|
|
@@ -674,6 +670,7 @@ def generate_image(ctx, model: str, prompt: str, output_path: Path, width: int,
|
|
|
674
670
|
except Exception as e:
|
|
675
671
|
click.echo(f"Error generating image: {str(e)}")
|
|
676
672
|
|
|
677
|
-
|
|
673
|
+
|
|
674
|
+
if __name__ == "__main__":
|
|
678
675
|
logging.getLogger().setLevel(logging.WARN)
|
|
679
|
-
cli()
|
|
676
|
+
cli()
|