xinference 0.11.1__py3-none-any.whl → 0.11.2.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +30 -0
- xinference/client/restful/restful_client.py +29 -0
- xinference/core/cache_tracker.py +12 -1
- xinference/core/supervisor.py +30 -2
- xinference/core/utils.py +12 -0
- xinference/core/worker.py +4 -1
- xinference/deploy/cmdline.py +126 -0
- xinference/deploy/test/test_cmdline.py +24 -0
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +501 -6
- xinference/model/llm/llm_family.py +84 -10
- xinference/model/llm/llm_family_modelscope.json +198 -7
- xinference/model/llm/memory.py +332 -0
- xinference/model/llm/pytorch/core.py +2 -0
- xinference/model/llm/pytorch/intern_vl.py +347 -0
- xinference/model/llm/utils.py +13 -0
- xinference/model/llm/vllm/core.py +5 -2
- xinference/model/rerank/core.py +23 -1
- xinference/model/utils.py +17 -7
- xinference/thirdparty/deepseek_vl/models/processing_vlm.py +1 -1
- xinference/thirdparty/deepseek_vl/models/siglip_vit.py +2 -2
- xinference/thirdparty/llava/mm_utils.py +3 -2
- xinference/thirdparty/llava/model/llava_arch.py +1 -1
- xinference/thirdparty/omnilmm/chat.py +6 -5
- {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/METADATA +8 -7
- {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/RECORD +31 -29
- {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/LICENSE +0 -0
- {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/WHEEL +0 -0
- {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.1.dist-info → xinference-0.11.2.post1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-05-
+ "date": "2024-05-24T19:39:58+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.11.1"
+ "full-revisionid": "ac8f33439c25e6fb05eba79e7932cbbadd068174",
+ "version": "0.11.2.post1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -493,6 +493,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/cached/list_cached_models",
+            self.list_cached_models,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
 
         # Clear the global Registry for the MetricsMiddleware, or
         # the MetricsMiddleware will register duplicated metrics if the port
@@ -688,6 +698,15 @@ class RESTfulAPI:
                 detail="Invalid input. Please specify the `model_engine` field.",
             )
 
+        if isinstance(gpu_idx, int):
+            gpu_idx = [gpu_idx]
+        if gpu_idx:
+            if len(gpu_idx) % replica:
+                raise HTTPException(
+                    status_code=400,
+                    detail="Invalid input. Allocated gpu must be a multiple of replica.",
+                )
+
         if peft_model_config is not None:
             peft_model_config = PeftModelConfig.from_dict(peft_model_config)
         else:
@@ -1470,6 +1489,17 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def list_cached_models(self) -> JSONResponse:
+        try:
+            data = await (await self._get_supervisor_ref()).list_cached_models()
+            return JSONResponse(content=data)
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def get_model_events(self, model_uid: str) -> JSONResponse:
         try:
             event_collector_ref = await self._get_event_collector_ref()
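For quick verification, the new endpoint can be exercised directly; a minimal sketch, assuming an unauthenticated server on the default local endpoint http://127.0.0.1:9997 (the path and GET method come from the route registration above, the "model_name"/"path" keys from the supervisor changes below):

    import requests

    # GET the cache listing endpoint added in this release.
    resp = requests.get("http://127.0.0.1:9997/v1/cached/list_cached_models")
    resp.raise_for_status()
    for entry in resp.json():
        print(entry.get("model_name"), "->", entry.get("path"))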
xinference/client/restful/restful_client.py
CHANGED

@@ -1102,6 +1102,35 @@ class Client:
         response_data = response.json()
         return response_data
 
+    def list_cached_models(self) -> List[Dict[Any, Any]]:
+        """
+        Get a list of cached models.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        List[Dict[Any, Any]]
+            The collection of cached models on the server.
+
+        Raises
+        ------
+        RuntimeError
+            Raised when the request fails, including the reason for the failure.
+        """
+
+        url = f"{self.base_url}/v1/cached/list_cached_models"
+        response = requests.get(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to list cached model, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
     def get_model_registration(
         self, model_type: str, model_name: str
     ) -> Dict[str, Any]:
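A minimal client-side sketch; it assumes the class is importable as xinference.client.RESTfulClient (the base_url/api_key constructor arguments match the cmdline changes below) and that a server is running locally:

    from xinference.client import RESTfulClient

    client = RESTfulClient(base_url="http://127.0.0.1:9997")
    # Each entry describes one cached model version on the cluster.
    for cached in client.list_cached_models():
        print(cached)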
xinference/core/cache_tracker.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from logging import getLogger
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import xoscar as xo
 
@@ -100,3 +100,14 @@ class CacheTrackerActor(xo.Actor):
 
     def get_model_version_count(self, model_name: str) -> int:
         return len(self.get_model_versions(model_name))
+
+    def list_cached_models(self) -> List[Dict[Any, Any]]:
+        cached_models = []
+        for model_name, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                if version_info["cache_status"]:
+                    ret = version_info.copy()
+                    ret["model_name"] = model_name
+                    cached_models.append(ret)
+        cached_models = sorted(cached_models, key=lambda x: x["model_name"])
+        return cached_models
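As an illustration of what the tracker now reports, here is a hedged, standalone sketch of the same filtering and sorting; the version-info dictionaries below are hypothetical and only use keys that appear elsewhere in this diff (cache_status, model_format, quantization, model_file_location), not the full schema:

    from typing import Any, Dict, List

    # Hypothetical tracker state; real entries carry more fields than shown here.
    version_info_by_name: Dict[str, List[Dict[str, Any]]] = {
        "some-llm": [
            {"cache_status": True, "model_format": "pytorch", "quantization": "none",
             "model_file_location": {"10.0.0.5": "/data/cache/some-llm-pytorch"}},
            {"cache_status": False, "model_format": "ggufv2", "quantization": "Q4_K_M",
             "model_file_location": {}},
        ],
    }

    cached = []
    for model_name, versions in version_info_by_name.items():
        for info in versions:
            if info["cache_status"]:  # keep only versions that are actually on disk
                entry = info.copy()
                entry["model_name"] = model_name
                cached.append(entry)
    cached.sort(key=lambda x: x["model_name"])
    print(cached)  # only the cached pytorch version is listed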
xinference/core/supervisor.py
CHANGED
@@ -34,6 +34,7 @@ from ..types import PeftModelConfig
 from .metrics import record_metrics
 from .resource import GPUStatus, ResourceStatus
 from .utils import (
+    assign_replica_gpu,
     build_replica_model_uid,
     gen_random_string,
     is_valid_model_uid,
@@ -769,7 +770,7 @@ class SupervisorActor(xo.StatelessActor):
                 raise ValueError(
                     f"Model is already in the model list, uid: {_replica_model_uid}"
                 )
-
+            replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
             nonlocal model_type
             worker_ref = (
                 target_ip_worker_ref
@@ -789,7 +790,7 @@ class SupervisorActor(xo.StatelessActor):
                 n_gpu=n_gpu,
                 request_limits=request_limits,
                 peft_model_config=peft_model_config,
-                gpu_idx=
+                gpu_idx=replica_gpu_idx,
                 **kwargs,
             )
             self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
@@ -980,6 +981,33 @@ class SupervisorActor(xo.StatelessActor):
             and list(self._worker_address_to_worker)[0] == self.address
         )
 
+    @log_async(logger=logger)
+    async def list_cached_models(self) -> List[Dict[str, Any]]:
+        cached_models = []
+        for worker in self._worker_address_to_worker.values():
+            ret = await worker.list_cached_models()
+            for model_version in ret:
+                model_name = model_version.get("model_name", None)
+                model_format = model_version.get("model_format", None)
+                model_size_in_billions = model_version.get(
+                    "model_size_in_billions", None
+                )
+                quantizations = model_version.get("quantization", None)
+                re_dict = model_version.get("model_file_location", None)
+                actor_ip_address, path = next(iter(re_dict.items()))
+
+                cache_entry = {
+                    "model_name": model_name,
+                    "model_format": model_format,
+                    "model_size_in_billions": model_size_in_billions,
+                    "quantizations": quantizations,
+                    "path": path,
+                    "Actor IP Address": actor_ip_address,
+                }
+
+                cached_models.append(cache_entry)
+        return cached_models
+
     @log_async(logger=logger)
     async def add_worker(self, worker_address: str):
         from .worker import WorkerActor
xinference/core/utils.py
CHANGED
@@ -191,3 +191,15 @@ def get_nvidia_gpu_info() -> Dict:
         nvmlShutdown()
     except:
         pass
+
+
+def assign_replica_gpu(
+    _replica_model_uid: str, gpu_idx: Union[int, List[int]]
+) -> List[int]:
+    model_uid, replica, rep_id = parse_replica_model_uid(_replica_model_uid)
+    rep_id, replica = int(rep_id), int(replica)
+    if isinstance(gpu_idx, int):
+        gpu_idx = [gpu_idx]
+    if isinstance(gpu_idx, list) and gpu_idx:
+        return gpu_idx[rep_id::replica]
+    return gpu_idx
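The striped GPU assignment above is plain list slicing; a short standalone sketch (the replica-uid parsing is simplified away) also shows why the REST API now requires the number of GPUs to be a multiple of the replica count:

    from typing import List

    def assign_gpus(gpu_idx: List[int], replica: int, rep_id: int) -> List[int]:
        # Same striding as gpu_idx[rep_id::replica] in assign_replica_gpu.
        return gpu_idx[rep_id::replica]

    gpus = [0, 1, 2, 3]
    print(assign_gpus(gpus, replica=2, rep_id=0))  # [0, 2]
    print(assign_gpus(gpus, replica=2, rep_id=1))  # [1, 3]

    # With three GPUs and two replicas the split would be uneven ([0, 2] vs [1]),
    # which is what the new len(gpu_idx) % replica check in restful_api.py rejects.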
xinference/core/worker.py
CHANGED
@@ -456,7 +456,7 @@ class WorkerActor(xo.StatelessActor):
     ) -> Tuple[str, List[str]]:
         env = {}
         devices = []
-        env_name = get_available_device_env_name()
+        env_name = get_available_device_env_name() or "CUDA_VISIBLE_DEVICES"
         if gpu_idx is None:
             if isinstance(n_gpu, int) or (n_gpu == "auto" and gpu_count() > 0):
                 # Currently, n_gpu=auto means using 1 GPU
@@ -781,6 +781,9 @@ class WorkerActor(xo.StatelessActor):
             except asyncio.CancelledError:  # pragma: no cover
                 break
 
+    async def list_cached_models(self) -> List[Dict[Any, Any]]:
+        return self._cache_tracker_ref.list_cached_models()
+
     @staticmethod
     def record_metrics(name, op, kwargs):
         record_metrics(name, op, kwargs)
xinference/deploy/cmdline.py
CHANGED
@@ -570,6 +570,44 @@ def list_model_registrations(
         raise NotImplementedError(f"List {model_type} is not implemented.")
 
 
+@cli.command("cached", help="List all cached models in Xinference.")
+@click.option(
+    "--endpoint",
+    "-e",
+    type=str,
+    help="Xinference endpoint.",
+)
+@click.option(
+    "--api-key",
+    "-ak",
+    default=None,
+    type=str,
+    help="Api-Key for access xinference api with authorization.",
+)
+def list_cached_models(
+    endpoint: Optional[str],
+    api_key: Optional[str],
+):
+    from tabulate import tabulate
+
+    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint, api_key=api_key)
+    if api_key is None:
+        client._set_token(get_stored_token(endpoint, client))
+
+    cached_models = client.list_cached_models()
+
+    print("cached_model: ")
+    headers = list(cached_models[0].keys())
+    table_data = []
+    for model in cached_models:
+        row_data = [
+            str(value) if value is not None else "-" for value in model.values()
+        ]
+        table_data.append(row_data)
+    print(tabulate(table_data, headers=headers, tablefmt="pretty"))
+
+
 @cli.command(
     "launch",
     help="Launch a model with the Xinference framework with the given parameters.",
@@ -1368,5 +1406,93 @@ def query_engine_by_model_name(
     )
 
 
+@cli.command(
+    "cal-model-mem",
+    help="calculate gpu mem usage with specified model size and context_length",
+)
+@click.option(
+    "--model-name",
+    "-n",
+    type=str,
+    help="The model name is optional.\
+        If provided, fetch model config from huggingface/modelscope;\
+        If not specified, use default model layer to estimate.",
+)
+@click.option(
+    "--size-in-billions",
+    "-s",
+    type=str,
+    required=True,
+    help="Specify the model size in billions of parameters. Format accept 1_8 and 1.8",
+)
+@click.option(
+    "--model-format",
+    "-f",
+    type=str,
+    required=True,
+    help="Specify the format of the model, e.g. pytorch, ggmlv3, etc.",
+)
+@click.option(
+    "--quantization",
+    "-q",
+    type=str,
+    default=None,
+    help="Define the quantization settings for the model.",
+)
+@click.option(
+    "--context-length",
+    "-c",
+    type=int,
+    required=True,
+    help="Specify the context length",
+)
+@click.option(
+    "--kv-cache-dtype",
+    type=int,
+    default=16,
+    help="Specified the kv_cache_dtype, one of: 8, 16, 32",
+)
+def cal_model_mem(
+    model_name: Optional[str],
+    size_in_billions: str,
+    model_format: str,
+    quantization: Optional[str],
+    context_length: int,
+    kv_cache_dtype: int,
+):
+    if kv_cache_dtype not in [8, 16, 32]:
+        print("Invalid kv_cache_dtype:", kv_cache_dtype)
+        os._exit(1)
+
+    import math
+
+    from ..model.llm.llm_family import convert_model_size_to_float
+    from ..model.llm.memory import estimate_llm_gpu_memory
+
+    mem_info = estimate_llm_gpu_memory(
+        model_size_in_billions=size_in_billions,
+        quantization=quantization,
+        context_length=context_length,
+        model_format=model_format,
+        model_name=model_name,
+        kv_cache_dtype=kv_cache_dtype,
+    )
+    if mem_info is None:
+        print("The Specified model parameters is not match: `%s`" % model_name)
+        os._exit(1)
+    total_mem_g = math.ceil(mem_info.total / 1024.0)
+    print("model_name:", model_name)
+    print("kv_cache_dtype:", kv_cache_dtype)
+    print("model size: %.1f B" % (convert_model_size_to_float(size_in_billions)))
+    print("quant: %s" % (quantization))
+    print("context: %d" % (context_length))
+    print("gpu mem usage:")
+    print("  model mem: %d MB" % (mem_info.model_mem))
+    print("  kv_cache: %d MB" % (mem_info.kv_cache_mem))
+    print("  overhead: %d MB" % (mem_info.overhead))
+    print("  active: %d MB" % (mem_info.activation_mem))
+    print("  total: %d MB (%d GB)" % (mem_info.total, total_mem_g))
+
+
 if __name__ == "__main__":
     cli()
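To complement the new cal-model-mem command, here is a hedged sketch of calling the underlying helper directly from Python; the function name and keyword arguments are taken from the hunk above, and only the result attributes that cal_model_mem itself reads (model_mem, kv_cache_mem, overhead, activation_mem, total) are assumed:

    import math

    from xinference.model.llm.memory import estimate_llm_gpu_memory

    # Estimate memory for a hypothetical 7B pytorch model at a 4096-token context.
    mem_info = estimate_llm_gpu_memory(
        model_size_in_billions="7",
        quantization=None,
        context_length=4096,
        model_format="pytorch",
        model_name=None,  # no name: fall back to a default layer estimate
        kv_cache_dtype=16,
    )
    if mem_info is not None:
        print("model mem: %d MB" % mem_info.model_mem)
        print("kv_cache:  %d MB" % mem_info.kv_cache_mem)
        print("total:     %d MB (%d GB)" % (mem_info.total, math.ceil(mem_info.total / 1024.0)))

The roughly equivalent CLI call would be `xinference cal-model-mem -s 7 -f pytorch -c 4096`.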
xinference/deploy/test/test_cmdline.py
CHANGED

@@ -19,6 +19,7 @@ from click.testing import CliRunner
 
 from ...client import Client
 from ..cmdline import (
+    list_cached_models,
     list_model_registrations,
     model_chat,
     model_generate,
@@ -278,3 +279,26 @@ def test_rotate_logs(setup_with_file_logging):
     with open(log_file, "r") as f:
         content = f.read()
     assert len(content) > 0
+
+
+def test_list_cached_models(setup):
+    endpoint, _ = setup
+    runner = CliRunner()
+
+    result = runner.invoke(
+        list_cached_models,
+        [
+            "--endpoint",
+            endpoint,
+        ],
+    )
+    assert result.exit_code == 0
+    assert "cached_model: " in result.stdout
+
+    # check if the output is in tabular format
+    assert "model_name" in result.stdout
+    assert "model_format" in result.stdout
+    assert "model_size_in_billions" in result.stdout
+    assert "quantizations" in result.stdout
+    assert "path" in result.stdout
+    assert "Actor IP Address" in result.stdout
xinference/model/llm/__init__.py
CHANGED
@@ -116,6 +116,7 @@ def _install():
     from .pytorch.core import PytorchChatModel, PytorchModel
     from .pytorch.deepseek_vl import DeepSeekVLChatModel
     from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
+    from .pytorch.intern_vl import InternVLChatModel
     from .pytorch.internlm2 import Internlm2PytorchChatModel
     from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
     from .pytorch.qwen_vl import QwenVLChatModel
@@ -156,6 +157,7 @@ def _install():
             QwenVLChatModel,
             YiVLChatModel,
             DeepSeekVLChatModel,
+            InternVLChatModel,
             PytorchModel,
         ]
     )