opengradient 0.3.24__py3-none-any.whl → 0.3.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +125 -98
- opengradient/account.py +6 -4
- opengradient/cli.py +151 -154
- opengradient/client.py +300 -362
- opengradient/defaults.py +7 -7
- opengradient/exceptions.py +25 -0
- opengradient/llm/__init__.py +7 -10
- opengradient/llm/og_langchain.py +34 -51
- opengradient/llm/og_openai.py +54 -61
- opengradient/mltools/__init__.py +2 -7
- opengradient/mltools/model_tool.py +20 -26
- opengradient/proto/infer_pb2.py +24 -29
- opengradient/proto/infer_pb2_grpc.py +95 -86
- opengradient/types.py +39 -35
- opengradient/utils.py +30 -31
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/METADATA +5 -92
- opengradient-0.3.26.dist-info/RECORD +26 -0
- opengradient-0.3.24.dist-info/RECORD +0 -26
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/LICENSE +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/WHEEL +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/entry_points.txt +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/top_level.txt +0 -0
opengradient/proto/infer_pb2.py
CHANGED
|
@@ -4,47 +4,42 @@
|
|
|
4
4
|
# source: infer.proto
|
|
5
5
|
# Protobuf Python Version: 5.27.2
|
|
6
6
|
"""Generated protocol buffer code."""
|
|
7
|
+
|
|
7
8
|
from google.protobuf import descriptor as _descriptor
|
|
8
9
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
10
|
from google.protobuf import runtime_version as _runtime_version
|
|
10
11
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
12
|
from google.protobuf.internal import builder as _builder
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
5,
|
|
15
|
-
27,
|
|
16
|
-
2,
|
|
17
|
-
'',
|
|
18
|
-
'infer.proto'
|
|
19
|
-
)
|
|
13
|
+
|
|
14
|
+
_runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, "", "infer.proto")
|
|
20
15
|
# @@protoc_insertion_point(imports)
|
|
21
16
|
|
|
22
17
|
_sym_db = _symbol_database.Default()
|
|
23
18
|
|
|
24
19
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
|
21
|
+
b'\n\x0binfer.proto\x12\tinference"[\n\x10InferenceRequest\x12\n\n\x02tx\x18\x01 \x01(\t\x12;\n\x10image_generation\x18\x06 \x01(\x0b\x32!.inference.ImageGenerationRequest"u\n\x16ImageGenerationRequest\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x13\n\x06height\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x12\n\x05width\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\t\n\x07_heightB\x08\n\x06_width"\x1b\n\rInferenceTxId\x12\n\n\x02id\x18\x01 \x01(\t"\xd4\x01\n\x0fInferenceStatus\x12\x31\n\x06status\x18\x01 \x01(\x0e\x32!.inference.InferenceStatus.Status\x12\x1a\n\rerror_message\x18\x02 \x01(\tH\x00\x88\x01\x01"`\n\x06Status\x12\x16\n\x12STATUS_UNSPECIFIED\x10\x00\x12\x16\n\x12STATUS_IN_PROGRESS\x10\x01\x12\x14\n\x10STATUS_COMPLETED\x10\x02\x12\x10\n\x0cSTATUS_ERROR\x10\x03\x42\x10\n\x0e_error_message"\xa4\x01\n\x0fInferenceResult\x12\x43\n\x17image_generation_result\x18\x05 \x01(\x0b\x32".inference.ImageGenerationResponse\x12\x17\n\npublic_key\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tsignature\x18\x08 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_public_keyB\x0c\n\n_signature"-\n\x17ImageGenerationResponse\x12\x12\n\nimage_data\x18\x01 \x01(\x0c\x32\xf6\x01\n\x10InferenceService\x12J\n\x11RunInferenceAsync\x12\x1b.inference.InferenceRequest\x1a\x18.inference.InferenceTxId\x12J\n\x12GetInferenceStatus\x12\x18.inference.InferenceTxId\x1a\x1a.inference.InferenceStatus\x12J\n\x12GetInferenceResult\x12\x18.inference.InferenceTxId\x1a\x1a.inference.InferenceResultb\x06proto3'
|
|
22
|
+
)
|
|
28
23
|
|
|
29
24
|
_globals = globals()
|
|
30
25
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
31
|
-
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR,
|
|
26
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "infer_pb2", _globals)
|
|
32
27
|
if not _descriptor._USE_C_DESCRIPTORS:
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
28
|
+
DESCRIPTOR._loaded_options = None
|
|
29
|
+
_globals["_INFERENCEREQUEST"]._serialized_start = 26
|
|
30
|
+
_globals["_INFERENCEREQUEST"]._serialized_end = 117
|
|
31
|
+
_globals["_IMAGEGENERATIONREQUEST"]._serialized_start = 119
|
|
32
|
+
_globals["_IMAGEGENERATIONREQUEST"]._serialized_end = 236
|
|
33
|
+
_globals["_INFERENCETXID"]._serialized_start = 238
|
|
34
|
+
_globals["_INFERENCETXID"]._serialized_end = 265
|
|
35
|
+
_globals["_INFERENCESTATUS"]._serialized_start = 268
|
|
36
|
+
_globals["_INFERENCESTATUS"]._serialized_end = 480
|
|
37
|
+
_globals["_INFERENCESTATUS_STATUS"]._serialized_start = 366
|
|
38
|
+
_globals["_INFERENCESTATUS_STATUS"]._serialized_end = 462
|
|
39
|
+
_globals["_INFERENCERESULT"]._serialized_start = 483
|
|
40
|
+
_globals["_INFERENCERESULT"]._serialized_end = 647
|
|
41
|
+
_globals["_IMAGEGENERATIONRESPONSE"]._serialized_start = 649
|
|
42
|
+
_globals["_IMAGEGENERATIONRESPONSE"]._serialized_end = 694
|
|
43
|
+
_globals["_INFERENCESERVICE"]._serialized_start = 697
|
|
44
|
+
_globals["_INFERENCESERVICE"]._serialized_end = 943
|
|
50
45
|
# @@protoc_insertion_point(module_scope)
|
opengradient/proto/infer_pb2_grpc.py
CHANGED
|
@@ -1,33 +1,33 @@
|
|
|
1
1
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
|
2
2
|
"""Client and server classes corresponding to protobuf-defined services."""
|
|
3
|
+
|
|
3
4
|
import grpc
|
|
4
|
-
import warnings
|
|
5
5
|
|
|
6
6
|
from . import infer_pb2 as infer__pb2
|
|
7
7
|
|
|
8
|
-
GRPC_GENERATED_VERSION =
|
|
8
|
+
GRPC_GENERATED_VERSION = "1.66.2"
|
|
9
9
|
GRPC_VERSION = grpc.__version__
|
|
10
10
|
_version_not_supported = False
|
|
11
11
|
|
|
12
12
|
try:
|
|
13
13
|
from grpc._utilities import first_version_is_lower
|
|
14
|
+
|
|
14
15
|
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
|
15
16
|
except ImportError:
|
|
16
17
|
_version_not_supported = True
|
|
17
18
|
|
|
18
19
|
if _version_not_supported:
|
|
19
20
|
raise RuntimeError(
|
|
20
|
-
f
|
|
21
|
-
+
|
|
22
|
-
+ f
|
|
23
|
-
+ f
|
|
24
|
-
+ f
|
|
21
|
+
f"The grpc package installed is at version {GRPC_VERSION},"
|
|
22
|
+
+ " but the generated code in infer_pb2_grpc.py depends on"
|
|
23
|
+
+ f" grpcio>={GRPC_GENERATED_VERSION}."
|
|
24
|
+
+ f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
|
|
25
|
+
+ f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
|
|
25
26
|
)
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
class InferenceServiceStub(object):
|
|
29
|
-
"""The inference service definition
|
|
30
|
-
"""
|
|
30
|
+
"""The inference service definition"""
|
|
31
31
|
|
|
32
32
|
def __init__(self, channel):
|
|
33
33
|
"""Constructor.
|
|
@@ -36,89 +36,91 @@ class InferenceServiceStub(object):
|
|
|
36
36
|
channel: A grpc.Channel.
|
|
37
37
|
"""
|
|
38
38
|
self.RunInferenceAsync = channel.unary_unary(
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
39
|
+
"/inference.InferenceService/RunInferenceAsync",
|
|
40
|
+
request_serializer=infer__pb2.InferenceRequest.SerializeToString,
|
|
41
|
+
response_deserializer=infer__pb2.InferenceTxId.FromString,
|
|
42
|
+
_registered_method=True,
|
|
43
|
+
)
|
|
43
44
|
self.GetInferenceStatus = channel.unary_unary(
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
45
|
+
"/inference.InferenceService/GetInferenceStatus",
|
|
46
|
+
request_serializer=infer__pb2.InferenceTxId.SerializeToString,
|
|
47
|
+
response_deserializer=infer__pb2.InferenceStatus.FromString,
|
|
48
|
+
_registered_method=True,
|
|
49
|
+
)
|
|
48
50
|
self.GetInferenceResult = channel.unary_unary(
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
51
|
+
"/inference.InferenceService/GetInferenceResult",
|
|
52
|
+
request_serializer=infer__pb2.InferenceTxId.SerializeToString,
|
|
53
|
+
response_deserializer=infer__pb2.InferenceResult.FromString,
|
|
54
|
+
_registered_method=True,
|
|
55
|
+
)
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
class InferenceServiceServicer(object):
|
|
56
|
-
"""The inference service definition
|
|
57
|
-
"""
|
|
59
|
+
"""The inference service definition"""
|
|
58
60
|
|
|
59
61
|
def RunInferenceAsync(self, request, context):
|
|
60
62
|
"""Missing associated documentation comment in .proto file."""
|
|
61
63
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
62
|
-
context.set_details(
|
|
63
|
-
raise NotImplementedError(
|
|
64
|
+
context.set_details("Method not implemented!")
|
|
65
|
+
raise NotImplementedError("Method not implemented!")
|
|
64
66
|
|
|
65
67
|
def GetInferenceStatus(self, request, context):
|
|
66
68
|
"""Missing associated documentation comment in .proto file."""
|
|
67
69
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
68
|
-
context.set_details(
|
|
69
|
-
raise NotImplementedError(
|
|
70
|
+
context.set_details("Method not implemented!")
|
|
71
|
+
raise NotImplementedError("Method not implemented!")
|
|
70
72
|
|
|
71
73
|
def GetInferenceResult(self, request, context):
|
|
72
74
|
"""Missing associated documentation comment in .proto file."""
|
|
73
75
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
74
|
-
context.set_details(
|
|
75
|
-
raise NotImplementedError(
|
|
76
|
+
context.set_details("Method not implemented!")
|
|
77
|
+
raise NotImplementedError("Method not implemented!")
|
|
76
78
|
|
|
77
79
|
|
|
78
80
|
def add_InferenceServiceServicer_to_server(servicer, server):
|
|
79
81
|
rpc_method_handlers = {
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
82
|
+
"RunInferenceAsync": grpc.unary_unary_rpc_method_handler(
|
|
83
|
+
servicer.RunInferenceAsync,
|
|
84
|
+
request_deserializer=infer__pb2.InferenceRequest.FromString,
|
|
85
|
+
response_serializer=infer__pb2.InferenceTxId.SerializeToString,
|
|
86
|
+
),
|
|
87
|
+
"GetInferenceStatus": grpc.unary_unary_rpc_method_handler(
|
|
88
|
+
servicer.GetInferenceStatus,
|
|
89
|
+
request_deserializer=infer__pb2.InferenceTxId.FromString,
|
|
90
|
+
response_serializer=infer__pb2.InferenceStatus.SerializeToString,
|
|
91
|
+
),
|
|
92
|
+
"GetInferenceResult": grpc.unary_unary_rpc_method_handler(
|
|
93
|
+
servicer.GetInferenceResult,
|
|
94
|
+
request_deserializer=infer__pb2.InferenceTxId.FromString,
|
|
95
|
+
response_serializer=infer__pb2.InferenceResult.SerializeToString,
|
|
96
|
+
),
|
|
95
97
|
}
|
|
96
|
-
generic_handler = grpc.method_handlers_generic_handler(
|
|
97
|
-
'inference.InferenceService', rpc_method_handlers)
|
|
98
|
+
generic_handler = grpc.method_handlers_generic_handler("inference.InferenceService", rpc_method_handlers)
|
|
98
99
|
server.add_generic_rpc_handlers((generic_handler,))
|
|
99
|
-
server.add_registered_method_handlers(
|
|
100
|
+
server.add_registered_method_handlers("inference.InferenceService", rpc_method_handlers)
|
|
100
101
|
|
|
101
102
|
|
|
102
|
-
|
|
103
|
+
# This class is part of an EXPERIMENTAL API.
|
|
103
104
|
class InferenceService(object):
|
|
104
|
-
"""The inference service definition
|
|
105
|
-
"""
|
|
105
|
+
"""The inference service definition"""
|
|
106
106
|
|
|
107
107
|
@staticmethod
|
|
108
|
-
def RunInferenceAsync(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
108
|
+
def RunInferenceAsync(
|
|
109
|
+
request,
|
|
110
|
+
target,
|
|
111
|
+
options=(),
|
|
112
|
+
channel_credentials=None,
|
|
113
|
+
call_credentials=None,
|
|
114
|
+
insecure=False,
|
|
115
|
+
compression=None,
|
|
116
|
+
wait_for_ready=None,
|
|
117
|
+
timeout=None,
|
|
118
|
+
metadata=None,
|
|
119
|
+
):
|
|
118
120
|
return grpc.experimental.unary_unary(
|
|
119
121
|
request,
|
|
120
122
|
target,
|
|
121
|
-
|
|
123
|
+
"/inference.InferenceService/RunInferenceAsync",
|
|
122
124
|
infer__pb2.InferenceRequest.SerializeToString,
|
|
123
125
|
infer__pb2.InferenceTxId.FromString,
|
|
124
126
|
options,
|
|
@@ -129,23 +131,26 @@ class InferenceService(object):
|
|
|
129
131
|
wait_for_ready,
|
|
130
132
|
timeout,
|
|
131
133
|
metadata,
|
|
132
|
-
_registered_method=True
|
|
134
|
+
_registered_method=True,
|
|
135
|
+
)
|
|
133
136
|
|
|
134
137
|
@staticmethod
|
|
135
|
-
def GetInferenceStatus(
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
138
|
+
def GetInferenceStatus(
|
|
139
|
+
request,
|
|
140
|
+
target,
|
|
141
|
+
options=(),
|
|
142
|
+
channel_credentials=None,
|
|
143
|
+
call_credentials=None,
|
|
144
|
+
insecure=False,
|
|
145
|
+
compression=None,
|
|
146
|
+
wait_for_ready=None,
|
|
147
|
+
timeout=None,
|
|
148
|
+
metadata=None,
|
|
149
|
+
):
|
|
145
150
|
return grpc.experimental.unary_unary(
|
|
146
151
|
request,
|
|
147
152
|
target,
|
|
148
|
-
|
|
153
|
+
"/inference.InferenceService/GetInferenceStatus",
|
|
149
154
|
infer__pb2.InferenceTxId.SerializeToString,
|
|
150
155
|
infer__pb2.InferenceStatus.FromString,
|
|
151
156
|
options,
|
|
@@ -156,23 +161,26 @@ class InferenceService(object):
|
|
|
156
161
|
wait_for_ready,
|
|
157
162
|
timeout,
|
|
158
163
|
metadata,
|
|
159
|
-
_registered_method=True
|
|
164
|
+
_registered_method=True,
|
|
165
|
+
)
|
|
160
166
|
|
|
161
167
|
@staticmethod
|
|
162
|
-
def GetInferenceResult(
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
168
|
+
def GetInferenceResult(
|
|
169
|
+
request,
|
|
170
|
+
target,
|
|
171
|
+
options=(),
|
|
172
|
+
channel_credentials=None,
|
|
173
|
+
call_credentials=None,
|
|
174
|
+
insecure=False,
|
|
175
|
+
compression=None,
|
|
176
|
+
wait_for_ready=None,
|
|
177
|
+
timeout=None,
|
|
178
|
+
metadata=None,
|
|
179
|
+
):
|
|
172
180
|
return grpc.experimental.unary_unary(
|
|
173
181
|
request,
|
|
174
182
|
target,
|
|
175
|
-
|
|
183
|
+
"/inference.InferenceService/GetInferenceResult",
|
|
176
184
|
infer__pb2.InferenceTxId.SerializeToString,
|
|
177
185
|
infer__pb2.InferenceResult.FromString,
|
|
178
186
|
options,
|
|
@@ -183,4 +191,5 @@ class InferenceService(object):
|
|
|
183
191
|
wait_for_ready,
|
|
184
192
|
timeout,
|
|
185
193
|
metadata,
|
|
186
|
-
_registered_method=True
|
|
194
|
+
_registered_method=True,
|
|
195
|
+
)
|
opengradient/types.py
CHANGED
|
@@ -1,18 +1,21 @@
|
|
|
1
|
+
import time
|
|
1
2
|
from dataclasses import dataclass
|
|
2
|
-
from typing import List, Tuple, Union, Dict, Optional
|
|
3
3
|
from enum import Enum, IntEnum
|
|
4
|
-
import
|
|
4
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
5
|
+
|
|
5
6
|
|
|
6
7
|
class CandleOrder(IntEnum):
|
|
7
8
|
ASCENDING = 0
|
|
8
9
|
DESCENDING = 1
|
|
9
10
|
|
|
11
|
+
|
|
10
12
|
class CandleType(IntEnum):
|
|
11
13
|
HIGH = 0
|
|
12
14
|
LOW = 1
|
|
13
15
|
OPEN = 2
|
|
14
16
|
CLOSE = 3
|
|
15
17
|
|
|
18
|
+
|
|
16
19
|
@dataclass
|
|
17
20
|
class HistoricalInputQuery:
|
|
18
21
|
currency_pair: str
|
|
@@ -28,65 +31,74 @@ class HistoricalInputQuery:
|
|
|
28
31
|
self.total_candles,
|
|
29
32
|
self.candle_duration_in_mins,
|
|
30
33
|
int(self.order),
|
|
31
|
-
[int(ct) for ct in self.candle_types]
|
|
34
|
+
[int(ct) for ct in self.candle_types],
|
|
32
35
|
)
|
|
33
36
|
|
|
34
37
|
@classmethod
|
|
35
|
-
def from_dict(cls, data: dict) ->
|
|
38
|
+
def from_dict(cls, data: dict) -> "HistoricalInputQuery":
|
|
36
39
|
"""Create HistoricalInputQuery from dictionary format"""
|
|
37
|
-
order = CandleOrder[data[
|
|
38
|
-
candle_types = [CandleType[ct.upper()] for ct in data[
|
|
39
|
-
|
|
40
|
+
order = CandleOrder[data["order"].upper()]
|
|
41
|
+
candle_types = [CandleType[ct.upper()] for ct in data["candle_types"]]
|
|
42
|
+
|
|
40
43
|
return cls(
|
|
41
|
-
currency_pair=data[
|
|
42
|
-
total_candles=int(data[
|
|
43
|
-
candle_duration_in_mins=int(data[
|
|
44
|
+
currency_pair=data["currency_pair"],
|
|
45
|
+
total_candles=int(data["total_candles"]),
|
|
46
|
+
candle_duration_in_mins=int(data["candle_duration_in_mins"]),
|
|
44
47
|
order=order,
|
|
45
|
-
candle_types=candle_types
|
|
48
|
+
candle_types=candle_types,
|
|
46
49
|
)
|
|
47
50
|
|
|
51
|
+
|
|
48
52
|
@dataclass
|
|
49
53
|
class Number:
|
|
50
54
|
value: int
|
|
51
55
|
decimals: int
|
|
52
56
|
|
|
57
|
+
|
|
53
58
|
@dataclass
|
|
54
59
|
class NumberTensor:
|
|
55
60
|
name: str
|
|
56
61
|
values: List[Tuple[int, int]]
|
|
57
62
|
|
|
63
|
+
|
|
58
64
|
@dataclass
|
|
59
65
|
class StringTensor:
|
|
60
66
|
name: str
|
|
61
67
|
values: List[str]
|
|
62
68
|
|
|
69
|
+
|
|
63
70
|
@dataclass
|
|
64
71
|
class ModelInput:
|
|
65
72
|
numbers: List[NumberTensor]
|
|
66
73
|
strings: List[StringTensor]
|
|
67
74
|
|
|
75
|
+
|
|
68
76
|
class InferenceMode:
|
|
69
77
|
VANILLA = 0
|
|
70
78
|
ZKML = 1
|
|
71
79
|
TEE = 2
|
|
72
80
|
|
|
81
|
+
|
|
73
82
|
class LlmInferenceMode:
|
|
74
83
|
VANILLA = 0
|
|
75
84
|
TEE = 1
|
|
76
85
|
|
|
86
|
+
|
|
77
87
|
@dataclass
|
|
78
88
|
class ModelOutput:
|
|
79
89
|
numbers: List[NumberTensor]
|
|
80
90
|
strings: List[StringTensor]
|
|
81
91
|
is_simulation_result: bool
|
|
82
92
|
|
|
93
|
+
|
|
83
94
|
@dataclass
|
|
84
95
|
class AbiFunction:
|
|
85
96
|
name: str
|
|
86
|
-
inputs: List[Union[str,
|
|
87
|
-
outputs: List[Union[str,
|
|
97
|
+
inputs: List[Union[str, "AbiFunction"]]
|
|
98
|
+
outputs: List[Union[str, "AbiFunction"]]
|
|
88
99
|
state_mutability: str
|
|
89
100
|
|
|
101
|
+
|
|
90
102
|
@dataclass
|
|
91
103
|
class Abi:
|
|
92
104
|
functions: List[AbiFunction]
|
|
@@ -95,32 +107,25 @@ class Abi:
|
|
|
95
107
|
def from_json(cls, abi_json):
|
|
96
108
|
functions = []
|
|
97
109
|
for item in abi_json:
|
|
98
|
-
if item[
|
|
99
|
-
inputs = cls._parse_inputs_outputs(item[
|
|
100
|
-
outputs = cls._parse_inputs_outputs(item[
|
|
101
|
-
functions.append(AbiFunction(
|
|
102
|
-
name=item['name'],
|
|
103
|
-
inputs=inputs,
|
|
104
|
-
outputs=outputs,
|
|
105
|
-
state_mutability=item['stateMutability']
|
|
106
|
-
))
|
|
110
|
+
if item["type"] == "function":
|
|
111
|
+
inputs = cls._parse_inputs_outputs(item["inputs"])
|
|
112
|
+
outputs = cls._parse_inputs_outputs(item["outputs"])
|
|
113
|
+
functions.append(AbiFunction(name=item["name"], inputs=inputs, outputs=outputs, state_mutability=item["stateMutability"]))
|
|
107
114
|
return cls(functions=functions)
|
|
108
115
|
|
|
109
116
|
@staticmethod
|
|
110
117
|
def _parse_inputs_outputs(items):
|
|
111
118
|
result = []
|
|
112
119
|
for item in items:
|
|
113
|
-
if
|
|
114
|
-
result.append(
|
|
115
|
-
name=item[
|
|
116
|
-
|
|
117
|
-
outputs=[],
|
|
118
|
-
state_mutability=''
|
|
119
|
-
))
|
|
120
|
+
if "components" in item:
|
|
121
|
+
result.append(
|
|
122
|
+
AbiFunction(name=item["name"], inputs=Abi._parse_inputs_outputs(item["components"]), outputs=[], state_mutability="")
|
|
123
|
+
)
|
|
120
124
|
else:
|
|
121
125
|
result.append(f"{item['name']}:{item['type']}")
|
|
122
126
|
return result
|
|
123
|
-
|
|
127
|
+
|
|
128
|
+
|
|
124
129
|
class LLM(str, Enum):
|
|
125
130
|
"""Enum for available LLM models"""
|
|
126
131
|
|
|
@@ -130,11 +135,13 @@ class LLM(str, Enum):
|
|
|
130
135
|
HERMES_3_LLAMA_3_1_70B = "NousResearch/Hermes-3-Llama-3.1-70B"
|
|
131
136
|
META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
|
|
132
137
|
|
|
138
|
+
|
|
133
139
|
class TEE_LLM(str, Enum):
|
|
134
140
|
"""Enum for LLM models available for TEE execution"""
|
|
135
141
|
|
|
136
142
|
META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
|
|
137
143
|
|
|
144
|
+
|
|
138
145
|
@dataclass
|
|
139
146
|
class SchedulerParams:
|
|
140
147
|
frequency: int
|
|
@@ -145,10 +152,7 @@ class SchedulerParams:
|
|
|
145
152
|
return int(time.time()) + (self.duration_hours * 60 * 60)
|
|
146
153
|
|
|
147
154
|
@staticmethod
|
|
148
|
-
def from_dict(data: Optional[Dict[str, int]]) -> Optional[
|
|
155
|
+
def from_dict(data: Optional[Dict[str, int]]) -> Optional["SchedulerParams"]:
|
|
149
156
|
if data is None:
|
|
150
157
|
return None
|
|
151
|
-
return SchedulerParams(
|
|
152
|
-
frequency=data.get('frequency', 600),
|
|
153
|
-
duration_hours=data.get('duration_hours', 2)
|
|
154
|
-
)
|
|
158
|
+
return SchedulerParams(frequency=data.get("frequency", 600), duration_hours=data.get("duration_hours", 2))
|