kaggle-1.7.3b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kaggle/LICENSE +201 -0
- kaggle/__init__.py +6 -0
- kaggle/api/__init__.py +0 -0
- kaggle/api/kaggle_api.py +614 -0
- kaggle/api/kaggle_api_extended.py +4657 -0
- kaggle/cli.py +1606 -0
- kaggle/configuration.py +206 -0
- kaggle/models/__init__.py +0 -0
- kaggle/models/api_blob_type.py +4 -0
- kaggle/models/dataset_column.py +228 -0
- kaggle/models/dataset_new_request.py +385 -0
- kaggle/models/dataset_new_version_request.py +287 -0
- kaggle/models/dataset_update_settings_request.py +310 -0
- kaggle/models/kaggle_models_extended.py +276 -0
- kaggle/models/kernel_push_request.py +556 -0
- kaggle/models/model_instance_new_version_request.py +145 -0
- kaggle/models/model_instance_update_request.py +351 -0
- kaggle/models/model_new_instance_request.py +417 -0
- kaggle/models/model_new_request.py +314 -0
- kaggle/models/model_update_request.py +282 -0
- kaggle/models/start_blob_upload_request.py +232 -0
- kaggle/models/start_blob_upload_response.py +137 -0
- kaggle/models/upload_file.py +169 -0
- kaggle/test/__init__.py +0 -0
- kaggle/test/test_authenticate.py +43 -0
- kaggle-1.7.3b1.dist-info/METADATA +348 -0
- kaggle-1.7.3b1.dist-info/RECORD +89 -0
- kaggle-1.7.3b1.dist-info/WHEEL +4 -0
- kaggle-1.7.3b1.dist-info/entry_points.txt +2 -0
- kaggle-1.7.3b1.dist-info/licenses/LICENSE.txt +201 -0
- kagglesdk/LICENSE +201 -0
- kagglesdk/__init__.py +2 -0
- kagglesdk/admin/__init__.py +0 -0
- kagglesdk/admin/services/__init__.py +0 -0
- kagglesdk/admin/services/inbox_file_service.py +22 -0
- kagglesdk/admin/types/__init__.py +0 -0
- kagglesdk/admin/types/inbox_file_service.py +74 -0
- kagglesdk/blobs/__init__.py +0 -0
- kagglesdk/blobs/services/__init__.py +0 -0
- kagglesdk/blobs/services/blob_api_service.py +25 -0
- kagglesdk/blobs/types/__init__.py +0 -0
- kagglesdk/blobs/types/blob_api_service.py +177 -0
- kagglesdk/common/__init__.py +0 -0
- kagglesdk/common/types/__init__.py +0 -0
- kagglesdk/common/types/file_download.py +102 -0
- kagglesdk/common/types/http_redirect.py +105 -0
- kagglesdk/competitions/__init__.py +0 -0
- kagglesdk/competitions/services/__init__.py +0 -0
- kagglesdk/competitions/services/competition_api_service.py +129 -0
- kagglesdk/competitions/types/__init__.py +0 -0
- kagglesdk/competitions/types/competition_api_service.py +1874 -0
- kagglesdk/competitions/types/competition_enums.py +53 -0
- kagglesdk/competitions/types/submission_status.py +9 -0
- kagglesdk/datasets/__init__.py +0 -0
- kagglesdk/datasets/services/__init__.py +0 -0
- kagglesdk/datasets/services/dataset_api_service.py +170 -0
- kagglesdk/datasets/types/__init__.py +0 -0
- kagglesdk/datasets/types/dataset_api_service.py +2777 -0
- kagglesdk/datasets/types/dataset_enums.py +82 -0
- kagglesdk/datasets/types/dataset_types.py +646 -0
- kagglesdk/education/__init__.py +0 -0
- kagglesdk/education/services/__init__.py +0 -0
- kagglesdk/education/services/education_api_service.py +19 -0
- kagglesdk/education/types/__init__.py +0 -0
- kagglesdk/education/types/education_api_service.py +248 -0
- kagglesdk/education/types/education_service.py +139 -0
- kagglesdk/kaggle_client.py +66 -0
- kagglesdk/kaggle_env.py +42 -0
- kagglesdk/kaggle_http_client.py +316 -0
- kagglesdk/kaggle_object.py +293 -0
- kagglesdk/kernels/__init__.py +0 -0
- kagglesdk/kernels/services/__init__.py +0 -0
- kagglesdk/kernels/services/kernels_api_service.py +109 -0
- kagglesdk/kernels/types/__init__.py +0 -0
- kagglesdk/kernels/types/kernels_api_service.py +1951 -0
- kagglesdk/kernels/types/kernels_enums.py +33 -0
- kagglesdk/models/__init__.py +0 -0
- kagglesdk/models/services/__init__.py +0 -0
- kagglesdk/models/services/model_api_service.py +255 -0
- kagglesdk/models/services/model_service.py +19 -0
- kagglesdk/models/types/__init__.py +0 -0
- kagglesdk/models/types/model_api_service.py +3719 -0
- kagglesdk/models/types/model_enums.py +60 -0
- kagglesdk/models/types/model_service.py +275 -0
- kagglesdk/models/types/model_types.py +286 -0
- kagglesdk/test/test_client.py +45 -0
- kagglesdk/users/__init__.py +0 -0
- kagglesdk/users/types/__init__.py +0 -0
- kagglesdk/users/types/users_enums.py +22 -0
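
For orientation, the largest addition in this release is the rewritten `KaggleApi` class in kaggle/api/kaggle_api_extended.py (diffed below), which now builds its transport through the new `kagglesdk` client rather than the old swagger-generated layer. The following is a minimal, illustrative sketch only — not part of the package — assuming the credential flow described in `authenticate()` and using method and field names (`competitions_list`, `dataset_list`, `.ref`) that appear in the diff:

```python
# Minimal sketch of driving the new client from Python.
# Assumes ~/.kaggle/kaggle.json or KAGGLE_USERNAME/KAGGLE_KEY are set,
# as handled by KaggleApi.authenticate() in the diff below.
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()  # env vars first, then ~/.kaggle/kaggle.json

# List featured competitions via the kagglesdk-backed transport.
for comp in api.competitions_list(category='featured', page=1):
    print(comp.ref)

# List small, currently "hottest" datasets (sizes are in bytes).
for ds in api.dataset_list(sort_by='hottest', max_size=10_000_000):
    print(ds.ref)
```

The `kaggle` console script declared in entry_points.txt presumably drives this same class through kaggle/cli.py.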
kaggle/api/kaggle_api_extended.py
@@ -0,0 +1,4657 @@
+#!/usr/bin/python
+#
+# Copyright 2024 Kaggle Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding=utf-8
+from __future__ import print_function
+
+import csv
+import io
+import os
+import shutil
+import sys
+import tarfile
+import tempfile
+import time
+import zipfile
+from os.path import expanduser
+from random import random
+
+import bleach
+import requests
+import urllib3.exceptions as urllib3_exceptions
+from requests import RequestException
+
+from kaggle.models.kaggle_models_extended import ResumableUploadResult, File
+
+from requests.adapters import HTTPAdapter
+from slugify import slugify
+from tqdm import tqdm
+from urllib3.util.retry import Retry
+from google.protobuf import field_mask_pb2
+
+from kaggle.configuration import Configuration
+from kagglesdk import KaggleClient, KaggleEnv
+from kagglesdk.admin.types.inbox_file_service import CreateInboxFileRequest
+from kagglesdk.blobs.types.blob_api_service import ApiStartBlobUploadRequest, \
+    ApiStartBlobUploadResponse, ApiBlobType
+from kagglesdk.competitions.types.competition_api_service import *
+from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, \
+    ApiListDatasetFilesRequest, \
+    ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, \
+    ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \
+    ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, \
+    ApiDatasetNewFile, ApiUpdateDatasetMetadataRequest, \
+    ApiGetDatasetMetadataRequest, ApiListDatasetFilesResponse, ApiDatasetFile
+from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, \
+    DatasetSortBy, DatasetFileTypeGroup, DatasetLicenseGroup
+from kagglesdk.datasets.types.dataset_types import DatasetSettings, \
+    SettingsLicense, DatasetCollaborator
+from kagglesdk.kernels.types.kernels_api_service import ApiListKernelsRequest, \
+    ApiListKernelFilesRequest, ApiSaveKernelRequest, ApiGetKernelRequest, \
+    ApiListKernelSessionOutputRequest, ApiGetKernelSessionStatusRequest
+from kagglesdk.kernels.types.kernels_enums import KernelsListSortType, \
+    KernelsListViewType
+from kagglesdk.models.types.model_api_service import ApiListModelsRequest, \
+    ApiCreateModelRequest, ApiGetModelRequest, ApiDeleteModelRequest, \
+    ApiUpdateModelRequest, ApiGetModelInstanceRequest, \
+    ApiCreateModelInstanceRequest, ApiCreateModelInstanceRequestBody, \
+    ApiListModelInstanceVersionFilesRequest, ApiUpdateModelInstanceRequest, \
+    ApiDeleteModelInstanceRequest, ApiCreateModelInstanceVersionRequest, \
+    ApiCreateModelInstanceVersionRequestBody, \
+    ApiDownloadModelInstanceVersionRequest, ApiDeleteModelInstanceVersionRequest
+from kagglesdk.models.types.model_enums import ListModelsOrderBy, \
+    ModelInstanceType, ModelFramework
+from ..models.dataset_column import DatasetColumn
+from ..models.upload_file import UploadFile
+
+
+class DirectoryArchive(object):
+
+  def __init__(self, fullpath, format):
+    self._fullpath = fullpath
+    self._format = format
+    self.name = None
+    self.path = None
+
+  def __enter__(self):
+    self._temp_dir = tempfile.mkdtemp()
+    _, dir_name = os.path.split(self._fullpath)
+    self.path = shutil.make_archive(
+        os.path.join(self._temp_dir, dir_name), self._format, self._fullpath)
+    _, self.name = os.path.split(self.path)
+    return self
+
+  def __exit__(self, *args):
+    shutil.rmtree(self._temp_dir)
+
+
+class ResumableUploadContext(object):
+
+  def __init__(self, no_resume=False):
+    self.no_resume = no_resume
+    self._temp_dir = os.path.join(tempfile.gettempdir(), '.kaggle/uploads')
+    self._file_uploads = []
+
+  def __enter__(self):
+    if self.no_resume:
+      return
+    self._create_temp_dir()
+    return self
+
+  def __exit__(self, exc_type, exc_value, exc_traceback):
+    if self.no_resume:
+      return
+    if exc_type is not None:
+      # Don't delete the upload file info when there is an error
+      # to give it a chance to retry/resume on the next invocation.
+      return
+    for file_upload in self._file_uploads:
+      file_upload.cleanup()
+
+  def get_upload_info_file_path(self, path):
+    return os.path.join(
+        self._temp_dir,
+        '%s.json' % path.replace(os.path.sep, '_').replace(':', '_'))
+
+  def new_resumable_file_upload(self, path, start_blob_upload_request):
+    file_upload = ResumableFileUpload(path, start_blob_upload_request, self)
+    self._file_uploads.append(file_upload)
+    file_upload.load()
+    return file_upload
+
+  def _create_temp_dir(self):
+    try:
+      os.makedirs(self._temp_dir)
+    except FileExistsError:
+      pass
+
+
+class ResumableFileUpload(object):
+  # Reference: https://cloud.google.com/storage/docs/resumable-uploads
+  # A resumable upload must be completed within a week of being initiated
+  RESUMABLE_UPLOAD_EXPIRY_SECONDS = 6 * 24 * 3600
+
+  def __init__(self, path, start_blob_upload_request, context):
+    self.path = path
+    self.start_blob_upload_request = start_blob_upload_request
+    self.context = context
+    self.timestamp = int(time.time())
+    self.start_blob_upload_response = None
+    self.can_resume = False
+    self.upload_complete = False
+    if self.context.no_resume:
+      return
+    self._upload_info_file_path = self.context.get_upload_info_file_path(path)
+
+  def get_token(self):
+    if self.upload_complete:
+      return self.start_blob_upload_response.token
+    return None
+
+  def load(self):
+    if self.context.no_resume:
+      return
+    self._load_previous_if_any()
+
+  def _load_previous_if_any(self):
+    if not os.path.exists(self._upload_info_file_path):
+      return False
+
+    try:
+      with io.open(self._upload_info_file_path, 'r') as f:
+        previous = ResumableFileUpload.from_dict(json.load(f), self.context)
+        if self._is_previous_valid(previous):
+          self.start_blob_upload_response = previous.start_blob_upload_response
+          self.timestamp = previous.timestamp
+          self.can_resume = True
+    except Exception as e:
+      print('Error while trying to load upload info:', e)
+
+  def _is_previous_valid(self, previous):
+    return previous.path == self.path and \
+           previous.start_blob_upload_request == self.start_blob_upload_request and \
+           previous.timestamp > time.time() - ResumableFileUpload.RESUMABLE_UPLOAD_EXPIRY_SECONDS
+
+  def upload_initiated(self, start_blob_upload_response):
+    if self.context.no_resume:
+      return
+
+    self.start_blob_upload_response = start_blob_upload_response
+    with io.open(self._upload_info_file_path, 'w') as f:
+      json.dump(self.to_dict(), f, indent=True)
+
+  def upload_completed(self):
+    if self.context.no_resume:
+      return
+
+    self.upload_complete = True
+    self._save()
+
+  def _save(self):
+    with io.open(self._upload_info_file_path, 'w') as f:
+      json.dump(self.to_dict(), f, indent=True)
+
+  def cleanup(self):
+    if self.context.no_resume:
+      return
+
+    try:
+      os.remove(self._upload_info_file_path)
+    except OSError:
+      pass
+
+  def to_dict(self):
+    return {
+        'path':
+            self.path,
+        'start_blob_upload_request':
+            self.start_blob_upload_request.to_dict(),
+        'timestamp':
+            self.timestamp,
+        'start_blob_upload_response':
+            self.start_blob_upload_response.to_dict()
+            if self.start_blob_upload_response is not None else None,
+        'upload_complete':
+            self.upload_complete,
+    }
+
+  def from_dict(other, context):
+    req = ApiStartBlobUploadRequest()
+    req.from_dict(other['start_blob_upload_request'])
+    new = ResumableFileUpload(
+        other['path'],
+        ApiStartBlobUploadRequest(**other['start_blob_upload_request']),
+        context)
+    new.timestamp = other.get('timestamp')
+    start_blob_upload_response = other.get('start_blob_upload_response')
+    if start_blob_upload_response is not None:
+      new.start_blob_upload_response = ApiStartBlobUploadResponse(
+          **start_blob_upload_response)
+    new.upload_complete = other.get('upload_complete') or False
+    return new
+
+  def to_str(self):
+    return str(self.to_dict())
+
+  def __repr__(self):
+    return self.to_str()
+
+
+class KaggleApi:
+  __version__ = '1.7.3b1'
+
+  CONFIG_NAME_PROXY = 'proxy'
+  CONFIG_NAME_COMPETITION = 'competition'
+  CONFIG_NAME_PATH = 'path'
+  CONFIG_NAME_USER = 'username'
+  CONFIG_NAME_KEY = 'key'
+  CONFIG_NAME_SSL_CA_CERT = 'ssl_ca_cert'
+
+  HEADER_API_VERSION = 'X-Kaggle-ApiVersion'
+  DATASET_METADATA_FILE = 'dataset-metadata.json'
+  OLD_DATASET_METADATA_FILE = 'datapackage.json'
+  KERNEL_METADATA_FILE = 'kernel-metadata.json'
+  MODEL_METADATA_FILE = 'model-metadata.json'
+  MODEL_INSTANCE_METADATA_FILE = 'model-instance-metadata.json'
+  MAX_NUM_INBOX_FILES_TO_UPLOAD = 1000
+  MAX_UPLOAD_RESUME_ATTEMPTS = 10
+
+  config_dir = os.environ.get('KAGGLE_CONFIG_DIR')
+
+  if not config_dir:
+    config_dir = os.path.join(expanduser('~'), '.kaggle')
+    # Use ~/.kaggle if it already exists for backwards compatibility,
+    # otherwise follow XDG base directory specification
+    if sys.platform.startswith('linux') and not os.path.exists(config_dir):
+      config_dir = os.path.join((os.environ.get('XDG_CONFIG_HOME') or
+                                 os.path.join(expanduser('~'), '.config')),
+                                'kaggle')
+
+  if not os.path.exists(config_dir):
+    os.makedirs(config_dir)
+
+  config_file = 'kaggle.json'
+  config = os.path.join(config_dir, config_file)
+  config_values = {}
+  already_printed_version_warning = False
+
+  args = {}  # DEBUG Add --local to use localhost
+  if os.environ.get('KAGGLE_API_ENVIRONMENT') == 'LOCALHOST':
+    args = {'--local'}
+
+  # Kernels valid types
+  valid_push_kernel_types = ['script', 'notebook']
+  valid_push_language_types = ['python', 'r', 'rmarkdown']
+  valid_push_pinning_types = ['original', 'latest']
+  valid_list_languages = ['all', 'python', 'r', 'sqlite', 'julia']
+  valid_list_kernel_types = ['all', 'script', 'notebook']
+  valid_list_output_types = ['all', 'visualization', 'data']
+  valid_list_sort_by = [
+      'hotness', 'commentCount', 'dateCreated', 'dateRun', 'relevance',
+      'scoreAscending', 'scoreDescending', 'viewCount', 'voteCount'
+  ]
+
+  # Competitions valid types
+  valid_competition_groups = [
+      'general', 'entered', 'community', 'hosted', 'unlaunched',
+      'unlaunched_community'
+  ]
+  valid_competition_categories = [
+      'all', 'featured', 'research', 'recruitment', 'gettingStarted', 'masters',
+      'playground'
+  ]
+  valid_competition_sort_by = [
+      'grouped', 'best', 'prize', 'earliestDeadline', 'latestDeadline',
+      'numberOfTeams', 'relevance', 'recentlyCreated'
+  ]
+
+  # Datasets valid types
+  valid_dataset_file_types = ['all', 'csv', 'sqlite', 'json', 'bigQuery']
+  valid_dataset_license_names = ['all', 'cc', 'gpl', 'odb', 'other']
+  valid_dataset_sort_bys = [
+      'hottest', 'votes', 'updated', 'active', 'published'
+  ]
+
+  # Models valid types
+  valid_model_sort_bys = [
+      'hotness', 'downloadCount', 'voteCount', 'notebookCount', 'createTime'
+  ]
+
+  # Command prefixes that are valid without authentication.
+  command_prefixes_allowing_anonymous_access = ('datasets download',
+                                                'datasets files')
+
+  # Attributes
+  competition_fields = [
+      'ref', 'deadline', 'category', 'reward', 'teamCount', 'userHasEntered'
+  ]
+  submission_fields = [
+      'fileName', 'date', 'description', 'status', 'publicScore', 'privateScore'
+  ]
+  competition_file_fields = ['name', 'totalBytes', 'creationDate']
+  competition_file_labels = ['name', 'size', 'creationDate']
+  competition_leaderboard_fields = [
+      'teamId', 'teamName', 'submissionDate', 'score'
+  ]
+  dataset_fields = [
+      'ref', 'title', 'totalBytes', 'lastUpdated', 'downloadCount', 'voteCount',
+      'usabilityRating'
+  ]
+  dataset_labels = [
+      'ref', 'title', 'size', 'lastUpdated', 'downloadCount', 'voteCount',
+      'usabilityRating'
+  ]
+  dataset_file_fields = ['name', 'total_bytes', 'creationDate']
+  model_fields = ['id', 'ref', 'title', 'subtitle', 'author']
+  model_all_fields = [
+      'id', 'ref', 'author', 'slug', 'title', 'subtitle', 'isPrivate',
+      'description', 'publishTime'
+  ]
+  model_file_fields = ['name', 'size', 'creationDate']
+
+  def _is_retriable(self, e):
+    return issubclass(type(e), ConnectionError) or \
+        issubclass(type(e), urllib3_exceptions.ConnectionError) or \
+        issubclass(type(e), urllib3_exceptions.ConnectTimeoutError) or \
+        issubclass(type(e), urllib3_exceptions.ProtocolError) or \
+        issubclass(type(e), requests.exceptions.ConnectionError) or \
+        issubclass(type(e), requests.exceptions.ConnectTimeout)
+
+  def _calculate_backoff_delay(self, attempt, initial_delay_millis,
+                               retry_multiplier, randomness_factor):
+    delay_ms = initial_delay_millis * (retry_multiplier**attempt)
+    random_wait_ms = int(random() - 0.5) * 2 * delay_ms * randomness_factor
+    total_delay = (delay_ms + random_wait_ms) / 1000.0
+    return total_delay
+
+  def with_retry(self,
+                 func,
+                 max_retries=10,
+                 initial_delay_millis=500,
+                 retry_multiplier=1.7,
+                 randomness_factor=0.5):
+
+    def retriable_func(*args):
+      for i in range(1, max_retries + 1):
+        try:
+          return func(*args)
+        except Exception as e:
+          if self._is_retriable(e) and i < max_retries:
+            total_delay = self._calculate_backoff_delay(i, initial_delay_millis,
+                                                        retry_multiplier,
+                                                        randomness_factor)
+            print('Request failed: %s. Will retry in %2.1f seconds' %
+                  (e, total_delay))
+            time.sleep(total_delay)
+            continue
+          raise
+
+    return retriable_func
+
+  ## Authentication
+
+  def authenticate(self):
+    """authenticate the user with the Kaggle API. This method will generate
+       a configuration, first checking the environment for credential
+       variables, and falling back to looking for the .kaggle/kaggle.json
+       configuration file.
+    """
+
+    config_data = {}
+    # Ex: 'datasets list', 'competitions files', 'models instances get', etc.
+    api_command = ' '.join(sys.argv[1:])
+
+    # Step 1: try getting username/password from environment
+    config_data = self.read_config_environment(config_data)
+
+    # Step 2: if credentials were not in env read in configuration file
+    if self.CONFIG_NAME_USER not in config_data \
+        or self.CONFIG_NAME_KEY not in config_data:
+      if os.path.exists(self.config):
+        config_data = self.read_config_file(config_data)
+      elif self._is_help_or_version_command(api_command) or (len(
+          sys.argv) > 2 and api_command.startswith(
+              self.command_prefixes_allowing_anonymous_access)):
+        # Some API commands should be allowed without authentication.
+        return
+      else:
+        raise IOError('Could not find {}. Make sure it\'s located in'
+                      ' {}. Or use the environment method. See setup'
+                      ' instructions at'
+                      ' https://github.com/Kaggle/kaggle-api/'.format(
+                          self.config_file, self.config_dir))
+
+    # Step 3: load into configuration!
+    self._load_config(config_data)
+
+  def _is_help_or_version_command(self, api_command):
+    """determines if the string command passed in is for a help or version
+       command.
+    Parameters
+    ==========
+    api_command: a string, 'datasets list', 'competitions files',
+          'models instances get', etc.
+    """
+    return api_command.endswith(('-h', '--help', '-v', '--version'))
+
+  def read_config_environment(self, config_data=None, quiet=False):
+    """read_config_environment is the second effort to get a username
+       and key to authenticate to the Kaggle API. The environment keys
+       are equivalent to the kaggle.json file, but with "KAGGLE_" prefix
+       to define a unique namespace.
+
+       Parameters
+       ==========
+       config_data: a partially loaded configuration dictionary (optional)
+       quiet: suppress verbose print of output (default is False)
+    """
+
+    # Add all variables that start with KAGGLE_ to config data
+
+    if config_data is None:
+      config_data = {}
+    for key, val in os.environ.items():
+      if key.startswith('KAGGLE_'):
+        config_key = key.replace('KAGGLE_', '', 1).lower()
+        config_data[config_key] = val
+
+    return config_data
+
+  ## Configuration
+
+  def _load_config(self, config_data):
+    """the final step of the authenticate steps, where we load the values
+       from config_data into the Configuration object.
+
+       Parameters
+       ==========
+       config_data: a dictionary with configuration values (keys) to read
+                    into self.config_values
+
+    """
+    # Username and password are required.
+
+    for item in [self.CONFIG_NAME_USER, self.CONFIG_NAME_KEY]:
+      if item not in config_data:
+        raise ValueError('Error: Missing %s in configuration.' % item)
+
+    configuration = Configuration()
+
+    # Add to the final configuration (required)
+
+    configuration.username = config_data[self.CONFIG_NAME_USER]
+    configuration.password = config_data[self.CONFIG_NAME_KEY]
+
+    # Proxy
+
+    if self.CONFIG_NAME_PROXY in config_data:
+      configuration.proxy = config_data[self.CONFIG_NAME_PROXY]
+
+    # Cert File
+
+    if self.CONFIG_NAME_SSL_CA_CERT in config_data:
+      configuration.ssl_ca_cert = config_data[self.CONFIG_NAME_SSL_CA_CERT]
+
+    # Keep config values with class instance, and load api client!
+
+    self.config_values = config_data
+
+  def read_config_file(self, config_data=None, quiet=False):
+    """read_config_file is the first effort to get a username
+       and key to authenticate to the Kaggle API. Since we can get the
+       username and password from the environment, it's not required.
+
+       Parameters
+       ==========
+       config_data: the Configuration object to save a username and
+                    password, if defined
+       quiet: suppress verbose print of output (default is False)
+    """
+    if config_data is None:
+      config_data = {}
+
+    if os.path.exists(self.config):
+
+      try:
+        if os.name != 'nt':
+          permissions = os.stat(self.config).st_mode
+          if (permissions & 4) or (permissions & 32):
+            print('Warning: Your Kaggle API key is readable by other '
+                  'users on this system! To fix this, you can run ' +
+                  '\'chmod 600 {}\''.format(self.config))
+
+        with open(self.config) as f:
+          config_data = json.load(f)
+      except:
+        pass
+
+    else:
+
+      # Warn the user that configuration will be reliant on environment
+      if not quiet:
+        print('No Kaggle API config file found, will use environment.')
+
+    return config_data
+
+  def _read_config_file(self):
+    """read in the configuration file, a json file defined at self.config"""
+
+    try:
+      with open(self.config, 'r') as f:
+        config_data = json.load(f)
+    except FileNotFoundError:
+      config_data = {}
+
+    return config_data
+
+  def _write_config_file(self, config_data, indent=2):
+    """write config data to file.
+
+       Parameters
+       ==========
+       config_data: the Configuration object to save a username and
+                    password, if defined
+       indent: number of tab indentations to use when writing json
+    """
+    with open(self.config, 'w') as f:
+      json.dump(config_data, f, indent=indent)
+
+  def set_config_value(self, name, value, quiet=False):
+    """a client helper function to set a configuration value, meaning
+       reading in the configuration file (if it exists), saving a new
+       config value, and then writing back
+
+       Parameters
+       ==========
+       name: the name of the value to set (key in dictionary)
+       value: the value to set at the key
+       quiet: disable verbose output if True (default is False)
+    """
+
+    config_data = self._read_config_file()
+
+    if value is not None:
+
+      # Update the config file with the value
+      config_data[name] = value
+
+      # Update the instance with the value
+      self.config_values[name] = value
+
+      # If defined by client, set and save!
+      self._write_config_file(config_data)
+
+      if not quiet:
+        self.print_config_value(name, separator=' is now set to: ')
+
+  def unset_config_value(self, name, quiet=False):
+    """unset a configuration value
+       Parameters
+       ==========
+       name: the name of the value to unset (remove key in dictionary)
+       quiet: disable verbose output if True (default is False)
+    """
+
+    config_data = self._read_config_file()
+
+    if name in config_data:
+
+      del config_data[name]
+
+      self._write_config_file(config_data)
+
+      if not quiet:
+        self.print_config_value(name, separator=' is now set to: ')
+
+  def get_config_value(self, name):
+    """ return a config value (with key name) if it's in the config_values,
+        otherwise return None
+
+       Parameters
+       ==========
+       name: the config value key to get
+
+    """
+    if name in self.config_values:
+      return self.config_values[name]
+
+  def get_default_download_dir(self, *subdirs):
+    """ Get the download path for a file. If not defined, return default
+        from config.
+
+       Parameters
+       ==========
+       subdirs: a single (or list of) subfolders under the basepath
+    """
+    # Look up value for key "path" in the config
+    path = self.get_config_value(self.CONFIG_NAME_PATH)
+
+    # If not set in config, default to present working directory
+    if path is None:
+      return os.getcwd()
+
+    return os.path.join(path, *subdirs)
+
+  def print_config_value(self, name, prefix='- ', separator=': '):
+    """print a single configuration value, based on a prefix and separator
+
+       Parameters
+       ==========
+       name: the key of the config valur in self.config_values to print
+       prefix: the prefix to print
+       separator: the separator to use (default is : )
+    """
+
+    value_out = 'None'
+    if name in self.config_values and self.config_values[name] is not None:
+      value_out = self.config_values[name]
+    print(prefix + name + separator + value_out)
+
+  def print_config_values(self, prefix='- '):
+    """a wrapper to print_config_value to print all configuration values
+       Parameters
+       ==========
+       prefix: the character prefix to put before the printed config value
+               defaults to "- "
+    """
+    print('Configuration values from ' + self.config_dir)
+    self.print_config_value(self.CONFIG_NAME_USER, prefix=prefix)
+    self.print_config_value(self.CONFIG_NAME_PATH, prefix=prefix)
+    self.print_config_value(self.CONFIG_NAME_PROXY, prefix=prefix)
+    self.print_config_value(self.CONFIG_NAME_COMPETITION, prefix=prefix)
+
+  def build_kaggle_client(self):
+    env = KaggleEnv.STAGING if '--staging' in self.args \
+        else KaggleEnv.ADMIN if '--admin' in self.args \
+        else KaggleEnv.LOCAL if '--local' in self.args \
+        else KaggleEnv.PROD
+    verbose = '--verbose' in self.args or '-v' in self.args
+    # config = self.api_client.configuration
+    return KaggleClient(
+        env=env,
+        verbose=verbose,
+        username=self.config_values['username'],
+        password=self.config_values['key'])
+
+  def camel_to_snake(self, name):
+    """
+    :param name: field in camel case
+    :return: field in snake case
+    """
+    name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
+    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
+
+  def lookup_enum(self, enum_class, item_name):
+    item = self.camel_to_snake(item_name).upper()
+    try:
+      return enum_class[item]
+    except KeyError:
+      prefix = self.camel_to_snake(enum_class.__name__).upper()
+      return enum_class[f'{prefix}_{self.camel_to_snake(item_name).upper()}']
+
+  def short_enum_name(self, value):
+    full_name = str(value)
+    names = full_name.split('.')
+    prefix_len = len(self.camel_to_snake(names[0])) + 1  # underscore
+    return names[1][prefix_len:].lower()
+
+  ## Competitions
+
+  def competitions_list(self,
+                        group=None,
+                        category=None,
+                        sort_by=None,
+                        page=1,
+                        search=None):
+    """ Make a call to list competitions, format the response, and return
+        a list of ApiCompetition instances
+
+       Parameters
+       ==========
+
+        page: the page to return (default is 1)
+        search: a search term to use (default is empty string)
+        sort_by: how to sort the result, see valid_competition_sort_by for options
+        category: category to filter result to; use 'all' to get closed competitions
+        group: group to filter result to
+    """
+    if group:
+      if group not in self.valid_competition_groups:
+        raise ValueError('Invalid group specified. Valid options are ' +
+                         str(self.valid_competition_groups))
+      if group == 'all':
+        group = CompetitionListTab.COMPETITION_LIST_TAB_EVERYTHING
+      else:
+        group = self.lookup_enum(CompetitionListTab, group)
+
+    if category:
+      if category not in self.valid_competition_categories:
+        raise ValueError('Invalid category specified. Valid options are ' +
+                         str(self.valid_competition_categories))
+      category = self.lookup_enum(HostSegment, category)
+
+    if sort_by:
+      if sort_by not in self.valid_competition_sort_by:
+        raise ValueError('Invalid sort_by specified. Valid options are ' +
+                         str(self.valid_competition_sort_by))
+      sort_by = self.lookup_enum(CompetitionSortBy, sort_by)
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListCompetitionsRequest()
+      request.group = group
+      request.page = page
+      request.category = category
+      request.search = search
+      request.sort_by = sort_by
+      response = kaggle.competitions.competition_api_client.list_competitions(
+          request)
+      return response.competitions
+
+  def competitions_list_cli(self,
+                            group=None,
+                            category=None,
+                            sort_by=None,
+                            page=1,
+                            search=None,
+                            csv_display=False):
+    """ A wrapper for competitions_list for the client.
+
+       Parameters
+       ==========
+       group: group to filter result to
+       category: category to filter result to
+       sort_by: how to sort the result, see valid_sort_by for options
+       page: the page to return (default is 1)
+       search: a search term to use (default is empty string)
+       csv_display: if True, print comma separated values
+    """
+    competitions = self.competitions_list(
+        group=group,
+        category=category,
+        sort_by=sort_by,
+        page=page,
+        search=search)
+    if competitions:
+      if csv_display:
+        self.print_csv(competitions, self.competition_fields)
+      else:
+        self.print_table(competitions, self.competition_fields)
+    else:
+      print('No competitions found')
+
+  def competition_submit(self, file_name, message, competition, quiet=False):
+    """ Submit a competition.
+
+       Parameters
+       ==========
+       file_name: the competition metadata file
+       message: the submission description
+       competition: the competition name; if not given use the 'competition' config value
+       quiet: suppress verbose output (default is False)
+    """
+    if competition is None:
+      competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+      if competition is not None and not quiet:
+        print('Using competition: ' + competition)
+
+    if competition is None:
+      raise ValueError('No competition specified')
+    else:
+      with self.build_kaggle_client() as kaggle:
+        request = ApiStartSubmissionUploadRequest()
+        request.competition_name = competition
+        request.file_name = os.path.basename(file_name)
+        request.content_length = os.path.getsize(file_name)
+        request.last_modified_epoch_seconds = int(os.path.getmtime(file_name))
+        response = kaggle.competitions.competition_api_client.start_submission_upload(
+            request)
+        upload_status = self.upload_complete(file_name, response.create_url,
+                                             quiet)
+        if upload_status != ResumableUploadResult.COMPLETE:
+          # Actual error is printed during upload_complete. Not
+          # ideal but changing would not be backwards compatible
+          return "Could not submit to competition"
+
+        submit_request = ApiCreateSubmissionRequest()
+        submit_request.competition_name = competition
+        submit_request.blob_file_tokens = response.token
+        submit_request.submission_description = message
+        submit_response = kaggle.competitions.competition_api_client.create_submission(
+            submit_request)
+        return submit_response
+
+  def competition_submit_cli(self,
+                             file_name,
+                             message,
+                             competition,
+                             competition_opt=None,
+                             quiet=False):
+    """ Submit a competition using the client. Arguments are same as for
+        competition_submit, except for extra arguments provided here.
+
+       Parameters
+       ==========
+       file_name: the competition metadata file
+       message: the submission description
+       competition: the competition name; if not given use the 'competition' config value
+       quiet: suppress verbose output (default is False)
+       competition_opt: an alternative competition option provided by cli
+    """
+    competition = competition or competition_opt
+    try:
+      submit_result = self.competition_submit(file_name, message, competition,
+                                              quiet)
+    except RequestException as e:
+      if e.response and e.response.status_code == 404:
+        print('Could not find competition - please verify that you '
+              'entered the correct competition ID and that the '
+              'competition is still accepting submissions.')
+        return None
+      else:
+        raise e
+    return submit_result.message
+
+  def competition_submissions(self,
+                              competition,
+                              group=None,
+                              sort=None,
+                              page_token=0,
+                              page_size=20):
+    """ Get the list of Submission for a particular competition.
+
+       Parameters
+       ==========
+       competition: the name of the competition
+       group: the submission group
+       sort: the sort-by option
+       page_token: token for pagination
+       page_size: the number of items per page
+    """
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListSubmissionsRequest()
+      request.competition_name = competition
+      request.page = page_token
+      request.group = group
+      request.sort_by = sort
+      response = kaggle.competitions.competition_api_client.list_submissions(
+          request)
+      return response.submissions
+
+  def competition_submissions_cli(self,
+                                  competition=None,
+                                  competition_opt=None,
+                                  csv_display=False,
+                                  page_token=None,
+                                  page_size=20,
+                                  quiet=False):
+    """ A wrapper to competition_submission, will return either json or csv
+        to the user. Additional parameters are listed below, see
+        competition_submissions for rest.
+
+       Parameters
+       ==========
+       competition: the name of the competition. If None, look to config
+       competition_opt: an alternative competition option provided by cli
+       csv_display: if True, print comma separated values
+       page_token: token for pagination
+       page_size: the number of items per page
+       quiet: suppress verbose output (default is False)
+    """
+    competition = competition or competition_opt
+    if competition is None:
+      competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+      if competition is not None and not quiet:
+        print('Using competition: ' + competition)
+
+    if competition is None:
+      raise ValueError('No competition specified')
+    else:
+      submissions = self.competition_submissions(
+          competition, page_token=page_token, page_size=page_size)
+      if submissions:
+        if csv_display:
+          self.print_csv(submissions, self.submission_fields)
+        else:
+          self.print_table(submissions, self.submission_fields)
+      else:
+        print('No submissions found')
+
+  def competition_list_files(self, competition, page_token=None, page_size=20):
+    """ List files for a competition.
+       Parameters
+       ==========
+       competition: the name of the competition
+       page_token: the page token for pagination
+       page_size: the number of items per page
+    """
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListDataFilesRequest()
+      request.competition_name = competition
+      request.page_token = page_token
+      request.page_size = page_size
+      response = kaggle.competitions.competition_api_client.list_data_files(
+          request)
+      return response
+
+  def competition_list_files_cli(self,
+                                 competition,
+                                 competition_opt=None,
+                                 csv_display=False,
+                                 page_token=None,
+                                 page_size=20,
+                                 quiet=False):
+    """ List files for a competition, if it exists.
+
+       Parameters
+       ==========
+       competition: the name of the competition. If None, look to config
+       competition_opt: an alternative competition option provided by cli
+       csv_display: if True, print comma separated values
+       page_token: the page token for pagination
+       page_size: the number of items per page
+       quiet: suppress verbose output (default is False)
+    """
+    competition = competition or competition_opt
+    if competition is None:
+      competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+      if competition is not None and not quiet:
+        print('Using competition: ' + competition)
+
+    if competition is None:
+      raise ValueError('No competition specified')
+    else:
+      result = self.competition_list_files(competition, page_token, page_size)
+      next_page_token = result.next_page_token
+      if next_page_token:
+        print('Next Page Token = {}'.format(next_page_token))
+      if result:
+        if csv_display:
+          self.print_csv(result.files, self.competition_file_fields,
+                         self.competition_file_labels)
+        else:
+          self.print_table(result.files, self.competition_file_fields,
+                           self.competition_file_labels)
+      else:
+        print('No files found')
+
+  def competition_download_file(self,
+                                competition,
+                                file_name,
+                                path=None,
+                                force=False,
+                                quiet=False):
+    """ Download a competition file to a designated location, or use
+        a default location.
+
+      Parameters
+      =========
+      competition: the name of the competition
+      file_name: the configuration file name
+      path: a path to download the file to
+      force: force the download if the file already exists (default False)
+      quiet: suppress verbose output (default is False)
+    """
+    if path is None:
+      effective_path = self.get_default_download_dir('competitions',
+                                                     competition)
+    else:
+      effective_path = path
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiDownloadDataFileRequest()
+      request.competition_name = competition
+      request.file_name = file_name
+      response = kaggle.competitions.competition_api_client.download_data_file(
+          request)
+      url = response.history[0].url
+      outfile = os.path.join(effective_path, url.split('?')[0].split('/')[-1])
+
+      if force or self.download_needed(response, outfile, quiet):
+        self.download_file(response, outfile, kaggle.http_client(), quiet,
+                           not force)
+
+  def competition_download_files(self,
+                                 competition,
+                                 path=None,
+                                 force=False,
+                                 quiet=True):
+    """ Download all competition files.
+
+      Parameters
+      =========
+      competition: the name of the competition
+      path: a path to download the file to
+      force: force the download if the file already exists (default False)
+      quiet: suppress verbose output (default is True)
+    """
+    if path is None:
+      effective_path = self.get_default_download_dir('competitions',
+                                                     competition)
+    else:
+      effective_path = path
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiDownloadDataFilesRequest()
+      request.competition_name = competition
+      response = kaggle.competitions.competition_api_client.download_data_files(
+          request)
+      url = response.url.split('?')[0]
+      outfile = os.path.join(effective_path,
+                             competition + '.' + url.split('.')[-1])
+
+      if force or self.download_needed(response, outfile, quiet):
+        self.download_file(response, outfile, quiet, not force)
+
+  def competition_download_cli(self,
+                               competition,
+                               competition_opt=None,
+                               file_name=None,
+                               path=None,
+                               force=False,
+                               quiet=False):
+    """ A wrapper to competition_download_files, but first will parse input
+        from API client. Additional parameters are listed here, see
+        competition_download for remaining.
+
+      Parameters
+      =========
+      competition: the name of the competition
+      competition_opt: an alternative competition option provided by cli
+      file_name: the configuration file name
+      path: a path to download the file to
+      force: force the download if the file already exists (default False)
+      quiet: suppress verbose output (default is False)
+    """
+    competition = competition or competition_opt
+    if competition is None:
+      competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+      if competition is not None and not quiet:
+        print('Using competition: ' + competition)
+
+    if competition is None:
+      raise ValueError('No competition specified')
+    else:
+      if file_name is None:
+        self.competition_download_files(competition, path, force, quiet)
+      else:
+        self.competition_download_file(competition, file_name, path, force,
+                                       quiet)
+
+  def competition_leaderboard_download(self, competition, path, quiet=True):
+    """ Download a competition leaderboard.
+
+      Parameters
+      =========
+      competition: the name of the competition
+      path: a path to download the file to
+      quiet: suppress verbose output (default is True)
+    """
+    with self.build_kaggle_client() as kaggle:
+      request = ApiDownloadLeaderboardRequest()
+      request.competition_name = competition
+      response = kaggle.competitions.competition_api_client.download_leaderboard(
+          request)
+      if path is None:
+        effective_path = self.get_default_download_dir('competitions',
+                                                       competition)
+      else:
+        effective_path = path
+
+      file_name = competition + '.zip'
+      outfile = os.path.join(effective_path, file_name)
+      self.download_file(response, outfile, quiet)
+
+  def competition_leaderboard_view(self, competition):
+    """ View a leaderboard based on a competition name.
+
+      Parameters
+      ==========
+      competition: the competition name to view leadboard for
+    """
+    with self.build_kaggle_client() as kaggle:
+      request = ApiGetLeaderboardRequest()
+      request.competition_name = competition
+      response = kaggle.competitions.competition_api_client.get_leaderboard(
+          request)
+      return response.submissions
+
+  def competition_leaderboard_cli(self,
+                                  competition,
+                                  competition_opt=None,
+                                  path=None,
+                                  view=False,
+                                  download=False,
+                                  csv_display=False,
+                                  quiet=False):
+    """ A wrapper for competition_leaderbord_view that will print the
+        results as a table or comma separated values
+
+      Parameters
+      ==========
+      competition: the competition name to view leadboard for
+      competition_opt: an alternative competition option provided by cli
+      path: a path to download to, if download is True
+      view: if True, show the results in the terminal as csv or table
+      download: if True, download the entire leaderboard
+      csv_display: if True, print comma separated values instead of table
+      quiet: suppress verbose output (default is False)
+    """
+    competition = competition or competition_opt
+    if not view and not download:
+      raise ValueError('Either --show or --download must be specified')
+
+    if competition is None:
+      competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+      if competition is not None and not quiet:
+        print('Using competition: ' + competition)
+
+    if competition is None:
+      raise ValueError('No competition specified')
+
+    if download:
+      self.competition_leaderboard_download(competition, path, quiet)
+
+    if view:
+      results = self.competition_leaderboard_view(competition)
+      if results:
+        if csv_display:
+          self.print_csv(results, self.competition_leaderboard_fields)
+        else:
+          self.print_table(results, self.competition_leaderboard_fields)
+      else:
+        print('No results found')
+
+  def dataset_list(self,
+                   sort_by=None,
+                   size=None,
+                   file_type=None,
+                   license_name=None,
+                   tag_ids=None,
+                   search=None,
+                   user=None,
+                   mine=False,
+                   page=1,
+                   max_size=None,
+                   min_size=None):
+    """ Return a list of datasets.
+
+      Parameters
+      ==========
+      sort_by: how to sort the result, see valid_dataset_sort_bys for options
+      size: Deprecated
+      file_type: the format, see valid_dataset_file_types for string options
+      license_name: string descriptor for license, see valid_dataset_license_names
+      tag_ids: tag identifiers to filter the search
+      search: a search term to use (default is empty string)
+      user: username to filter the search to
+      mine: boolean if True, group is changed to "my" to return personal
+      page: the page to return (default is 1)
+      max_size: the maximum size of the dataset to return (bytes)
+      min_size: the minimum size of the dataset to return (bytes)
+    """
+    if sort_by:
+      if sort_by not in self.valid_dataset_sort_bys:
+        raise ValueError('Invalid sort by specified. Valid options are ' +
+                         str(self.valid_dataset_sort_bys))
+      else:
+        sort_by = self.lookup_enum(DatasetSortBy, sort_by)
+
+    if size:
+      raise ValueError(
+          'The --size parameter has been deprecated. ' +
+          'Please use --max-size and --min-size to filter dataset sizes.')
+
+    if file_type:
+      if file_type not in self.valid_dataset_file_types:
+        raise ValueError('Invalid file type specified. Valid options are ' +
+                         str(self.valid_dataset_file_types))
+      else:
+        file_type = self.lookup_enum(DatasetFileTypeGroup, file_type)
+
+    if license_name:
+      if license_name not in self.valid_dataset_license_names:
+        raise ValueError('Invalid license specified. Valid options are ' +
+                         str(self.valid_dataset_license_names))
+      else:
+        license_name = self.lookup_enum(DatasetLicenseGroup, license_name)
+
+    if int(page) <= 0:
+      raise ValueError('Page number must be >= 1')
+
+    if max_size and min_size:
+      if int(max_size) < int(min_size):
+        raise ValueError('Max Size must be max_size >= min_size')
+    if max_size and int(max_size) <= 0:
+      raise ValueError('Max Size must be > 0')
+    elif min_size and int(min_size) < 0:
+      raise ValueError('Min Size must be >= 0')
+
+    group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_PUBLIC
+    if mine:
+      group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_MY
+      if user:
+        raise ValueError('Cannot specify both mine and a user')
+    if user:
+      group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_USER
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListDatasetsRequest()
+      request.group = group
+      request.sort_by = sort_by
+      request.file_type = file_type
+      request.license = license_name
+      request.tag_ids = tag_ids
+      request.search = search
+      request.user = user
+      request.page = page
+      request.max_size = max_size
+      request.min_size = min_size
+      response = kaggle.datasets.dataset_api_client.list_datasets(request)
+      return response.datasets
+
+  def dataset_list_cli(self,
+                       sort_by=None,
+                       size=None,
+                       file_type=None,
+                       license_name=None,
+                       tag_ids=None,
+                       search=None,
+                       user=None,
+                       mine=False,
+                       page=1,
+                       csv_display=False,
+                       max_size=None,
+                       min_size=None):
+    """ A wrapper to dataset_list for the client. Additional parameters
+        are described here, see dataset_list for others.
+
+      Parameters
+      ==========
+      sort_by: how to sort the result, see valid_dataset_sort_bys for options
+      size: DEPRECATED
+      file_type: the format, see valid_dataset_file_types for string options
+      license_name: string descriptor for license, see valid_dataset_license_names
+      tag_ids: tag identifiers to filter the search
+      search: a search term to use (default is empty string)
+      user: username to filter the search to
+      mine: boolean if True, group is changed to "my" to return personal
+      page: the page to return (default is 1)
+      csv_display: if True, print comma separated values instead of table
+      max_size: the maximum size of the dataset to return (bytes)
+      min_size: the minimum size of the dataset to return (bytes)
+    """
+    datasets = self.dataset_list(sort_by, size, file_type, license_name,
+                                 tag_ids, search, user, mine, page, max_size,
+                                 min_size)
+    if datasets:
+      if csv_display:
+        self.print_csv(datasets, self.dataset_fields, self.dataset_labels)
+      else:
+        self.print_table(datasets, self.dataset_fields, self.dataset_labels)
+    else:
+      print('No datasets found')
+
+  def dataset_metadata_prep(self, dataset, path):
+    if dataset is None:
+      raise ValueError('A dataset must be specified')
+    if '/' in dataset:
+      self.validate_dataset_string(dataset)
+      dataset_urls = dataset.split('/')
+      owner_slug = dataset_urls[0]
+      dataset_slug = dataset_urls[1]
+    else:
+      owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
+      dataset_slug = dataset
+
1319
|
+
if path is None:
|
|
1320
|
+
effective_path = self.get_default_download_dir('datasets', owner_slug,
|
|
1321
|
+
dataset_slug)
|
|
1322
|
+
else:
|
|
1323
|
+
effective_path = path
|
|
1324
|
+
|
|
1325
|
+
return (owner_slug, dataset_slug, effective_path)
|
|
1326
|
+
|
|
1327
|
+
def dataset_metadata_update(self, dataset, path):
|
|
1328
|
+
(owner_slug, dataset_slug,
|
|
1329
|
+
effective_path) = self.dataset_metadata_prep(dataset, path)
|
|
1330
|
+
meta_file = self.get_dataset_metadata_file(effective_path)
|
|
1331
|
+
with open(meta_file, 'r') as f:
|
|
1332
|
+
s = json.load(f)
|
|
1333
|
+
metadata = json.loads(s)
|
|
1334
|
+
update_settings = DatasetSettings()
|
|
1335
|
+
update_settings.title = metadata.get('title') or ''
|
|
1336
|
+
update_settings.subtitle = metadata.get('subtitle') or ''
|
|
1337
|
+
update_settings.description = metadata.get('description') or ''
|
|
1338
|
+
update_settings.is_private = metadata.get('isPrivate') or False
|
|
1339
|
+
update_settings.licenses = [
|
|
1340
|
+
self._new_license(l['name']) for l in metadata['licenses']
|
|
1341
|
+
] if metadata.get('licenses') else []
|
|
1342
|
+
update_settings.keywords = metadata.get('keywords')
|
|
1343
|
+
update_settings.collaborators = [
|
|
1344
|
+
self._new_collaborator(c['username'], c['role'])
|
|
1345
|
+
for c in metadata['collaborators']
|
|
1346
|
+
] if metadata.get('collaborators') else []
|
|
1347
|
+
update_settings.data = metadata.get('data')
|
|
1348
|
+
request = ApiUpdateDatasetMetadataRequest()
|
|
1349
|
+
request.owner_slug = owner_slug
|
|
1350
|
+
request.dataset_slug = dataset_slug
|
|
1351
|
+
request.settings = update_settings
|
|
1352
|
+
with self.build_kaggle_client() as kaggle:
|
|
1353
|
+
response = kaggle.datasets.dataset_api_client.update_dataset_metadata(
|
|
1354
|
+
request)
|
|
1355
|
+
if len(response.errors) > 0:
|
|
1356
|
+
[print(e['message']) for e in response.errors]
|
|
1357
|
+
exit(1)
|
|
1358
|
+
|
|
1359
|
+
@staticmethod
|
|
1360
|
+
def _new_license(name):
|
|
1361
|
+
l = SettingsLicense()
|
|
1362
|
+
l.name = name
|
|
1363
|
+
return l
|
|
1364
|
+
|
|
1365
|
+
@staticmethod
|
|
1366
|
+
def _new_collaborator(name, role):
|
|
1367
|
+
u = DatasetCollaborator()
|
|
1368
|
+
u.username = name
|
|
1369
|
+
u.role = role
|
|
1370
|
+
return u
|
|
1371
|
+
|
|
1372
|
+
def dataset_metadata(self, dataset, path):
|
|
1373
|
+
(owner_slug, dataset_slug,
|
|
1374
|
+
effective_path) = self.dataset_metadata_prep(dataset, path)
|
|
1375
|
+
|
|
1376
|
+
if not os.path.exists(effective_path):
|
|
1377
|
+
os.makedirs(effective_path)
|
|
1378
|
+
|
|
1379
|
+
with self.build_kaggle_client() as kaggle:
|
|
1380
|
+
request = ApiGetDatasetMetadataRequest()
|
|
1381
|
+
request.owner_slug = owner_slug
|
|
1382
|
+
request.dataset_slug = dataset_slug
|
|
1383
|
+
response = kaggle.datasets.dataset_api_client.get_dataset_metadata(
|
|
1384
|
+
request)
|
|
1385
|
+
if response.error_message:
|
|
1386
|
+
raise Exception(response.error_message)
|
|
1387
|
+
|
|
1388
|
+
meta_file = os.path.join(effective_path, self.DATASET_METADATA_FILE)
|
|
1389
|
+
with open(meta_file, 'w') as f:
|
|
1390
|
+
json.dump(
|
|
1391
|
+
response.to_json(response.info),
|
|
1392
|
+
f,
|
|
1393
|
+
indent=2,
|
|
1394
|
+
default=lambda o: o.__dict__)
|
|
1395
|
+
|
|
1396
|
+
return meta_file
|
|
1397
|
+
|
|
1398
|
+
def dataset_metadata_cli(self, dataset, path, update, dataset_opt=None):
|
|
1399
|
+
dataset = dataset or dataset_opt
|
|
1400
|
+
if (update):
|
|
1401
|
+
print('updating dataset metadata')
|
|
1402
|
+
self.dataset_metadata_update(dataset, path)
|
|
1403
|
+
print('successfully updated dataset metadata')
|
|
1404
|
+
else:
|
|
1405
|
+
meta_file = self.dataset_metadata(dataset, path)
|
|
1406
|
+
print('Downloaded metadata to ' + meta_file)
|
|
1407
|
+
|
|
1408
|
+
def dataset_list_files(self, dataset, page_token=None, page_size=20):
|
|
1409
|
+
""" List files for a dataset.
|
|
1410
|
+
|
|
1411
|
+
Parameters
|
|
1412
|
+
==========
|
|
1413
|
+
dataset: the string identified of the dataset
|
|
1414
|
+
should be in format [owner]/[dataset-name]
|
|
1415
|
+
page_token: the page token for pagination
|
|
1416
|
+
page_size: the number of items per page
|
|
1417
|
+
"""
|
|
1418
|
+
if dataset is None:
|
|
1419
|
+
raise ValueError('A dataset must be specified')
|
|
1420
|
+
owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string(
|
|
1421
|
+
dataset)
|
|
1422
|
+
|
|
1423
|
+
with self.build_kaggle_client() as kaggle:
|
|
1424
|
+
request = ApiListDatasetFilesRequest()
|
|
1425
|
+
request.owner_slug = owner_slug
|
|
1426
|
+
request.dataset_slug = dataset_slug
|
|
1427
|
+
request.dataset_version_number = dataset_version_number
|
|
1428
|
+
request.page_token = page_token
|
|
1429
|
+
request.page_size = page_size
|
|
1430
|
+
response = kaggle.datasets.dataset_api_client.list_dataset_files(request)
|
|
1431
|
+
return response
|
|
1432
|
+
|
|
1433
|
+
def dataset_list_files_cli(self,
|
|
1434
|
+
dataset,
|
|
1435
|
+
dataset_opt=None,
|
|
1436
|
+
csv_display=False,
|
|
1437
|
+
page_token=None,
|
|
1438
|
+
page_size=20):
|
|
1439
|
+
""" A wrapper to dataset_list_files for the client
|
|
1440
|
+
(list files for a dataset).
|
|
1441
|
+
Parameters
|
|
1442
|
+
==========
|
|
1443
|
+
dataset: the string identified of the dataset
|
|
1444
|
+
should be in format [owner]/[dataset-name]
|
|
1445
|
+
dataset_opt: an alternative option to providing a dataset
|
|
1446
|
+
csv_display: if True, print comma separated values instead of table
|
|
1447
|
+
page_token: the page token for pagination
|
|
1448
|
+
page_size: the number of items per page
|
|
1449
|
+
"""
|
|
1450
|
+
dataset = dataset or dataset_opt
|
|
1451
|
+
result = self.dataset_list_files(dataset, page_token, page_size)
|
|
1452
|
+
|
|
1453
|
+
if result:
|
|
1454
|
+
if result.error_message:
|
|
1455
|
+
print(result.error_message)
|
|
1456
|
+
else:
|
|
1457
|
+
next_page_token = result.next_page_token
|
|
1458
|
+
if next_page_token:
|
|
1459
|
+
print('Next Page Token = {}'.format(next_page_token))
|
|
1460
|
+
fields = ['name', 'size', 'creationDate']
|
|
1461
|
+
ApiDatasetFile.size = ApiDatasetFile.total_bytes
|
|
1462
|
+
if csv_display:
|
|
1463
|
+
self.print_csv(result.files, fields)
|
|
1464
|
+
else:
|
|
1465
|
+
self.print_table(result.files, fields)
|
|
1466
|
+
else:
|
|
1467
|
+
print('No files found')
|
|
1468
|
+
|
|
1469
|
+
def dataset_status(self, dataset):
|
|
1470
|
+
""" Call to get the status of a dataset from the API.
|
|
1471
|
+
Parameters
|
|
1472
|
+
==========
|
|
1473
|
+
dataset: the string identifier of the dataset
|
|
1474
|
+
should be in format [owner]/[dataset-name]
|
|
1475
|
+
"""
|
|
1476
|
+
if dataset is None:
|
|
1477
|
+
raise ValueError('A dataset must be specified')
|
|
1478
|
+
if '/' in dataset:
|
|
1479
|
+
self.validate_dataset_string(dataset)
|
|
1480
|
+
dataset_urls = dataset.split('/')
|
|
1481
|
+
owner_slug = dataset_urls[0]
|
|
1482
|
+
dataset_slug = dataset_urls[1]
|
|
1483
|
+
else:
|
|
1484
|
+
owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
|
|
1485
|
+
dataset_slug = dataset
|
|
1486
|
+
|
|
1487
|
+
with self.build_kaggle_client() as kaggle:
|
|
1488
|
+
request = ApiGetDatasetStatusRequest()
|
|
1489
|
+
request.owner_slug = owner_slug
|
|
1490
|
+
request.dataset_slug = dataset_slug
|
|
1491
|
+
response = kaggle.datasets.dataset_api_client.get_dataset_status(request)
|
|
1492
|
+
return response.status.name.lower()
|
|
1493
|
+
|
|
1494
|
+
def dataset_status_cli(self, dataset, dataset_opt=None):
|
|
1495
|
+
""" A wrapper for client for dataset_status, with additional
|
|
1496
|
+
dataset_opt to get the status of a dataset from the API.
|
|
1497
|
+
Parameters
|
|
1498
|
+
==========
|
|
1499
|
+
dataset_opt: an alternative to dataset
|
|
1500
|
+
"""
|
|
1501
|
+
dataset = dataset or dataset_opt
|
|
1502
|
+
return self.dataset_status(dataset)
|
|
1503
|
+
|
|
1504
|
+
def dataset_download_file(self,
|
|
1505
|
+
dataset,
|
|
1506
|
+
file_name,
|
|
1507
|
+
path=None,
|
|
1508
|
+
force=False,
|
|
1509
|
+
quiet=True,
|
|
1510
|
+
licenses=[]):
|
|
1511
|
+
""" Download a single file for a dataset.
|
|
1512
|
+
|
|
1513
|
+
Parameters
|
|
1514
|
+
==========
|
|
1515
|
+
dataset: the string identified of the dataset
|
|
1516
|
+
should be in format [owner]/[dataset-name]
|
|
1517
|
+
file_name: the dataset configuration file
|
|
1518
|
+
path: if defined, download to this location
|
|
1519
|
+
force: force the download if the file already exists (default False)
|
|
1520
|
+
quiet: suppress verbose output (default is True)
|
|
1521
|
+
licenses: a list of license names, e.g. ['CC0-1.0']
|
|
1522
|
+
"""
|
|
1523
|
+
if '/' in dataset:
|
|
1524
|
+
self.validate_dataset_string(dataset)
|
|
1525
|
+
owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string(
|
|
1526
|
+
dataset)
|
|
1527
|
+
else:
|
|
1528
|
+
owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
|
|
1529
|
+
dataset_slug = dataset
|
|
1530
|
+
dataset_version_number = None
|
|
1531
|
+
|
|
1532
|
+
if path is None:
|
|
1533
|
+
effective_path = self.get_default_download_dir('datasets', owner_slug,
|
|
1534
|
+
dataset_slug)
|
|
1535
|
+
else:
|
|
1536
|
+
effective_path = path
|
|
1537
|
+
|
|
1538
|
+
self._print_dataset_url_and_license(owner_slug, dataset_slug,
|
|
1539
|
+
dataset_version_number, licenses)
|
|
1540
|
+
|
|
1541
|
+
with self.build_kaggle_client() as kaggle:
|
|
1542
|
+
request = ApiDownloadDatasetRequest()
|
|
1543
|
+
request.owner_slug = owner_slug
|
|
1544
|
+
request.dataset_slug = dataset_slug
|
|
1545
|
+
request.dataset_version_number = dataset_version_number
|
|
1546
|
+
request.file_name = file_name
|
|
1547
|
+
response = kaggle.datasets.dataset_api_client.download_dataset(request)
|
|
1548
|
+
url = response.history[0].url
|
|
1549
|
+
outfile = os.path.join(effective_path, url.split('?')[0].split('/')[-1])
|
|
1550
|
+
|
|
1551
|
+
if force or self.download_needed(response, outfile, quiet):
|
|
1552
|
+
self.download_file(response, outfile, quiet, not force)
|
|
1553
|
+
return True
|
|
1554
|
+
else:
|
|
1555
|
+
return False
|
|
1556
|
+
|
|
1557
|
+
def dataset_download_files(self,
|
|
1558
|
+
dataset,
|
|
1559
|
+
path=None,
|
|
1560
|
+
force=False,
|
|
1561
|
+
quiet=True,
|
|
1562
|
+
unzip=False,
|
|
1563
|
+
licenses=[]):
|
|
1564
|
+
""" Download all files for a dataset.
|
|
1565
|
+
|
|
1566
|
+
Parameters
|
|
1567
|
+
==========
|
|
1568
|
+
dataset: the string identified of the dataset
|
|
1569
|
+
should be in format [owner]/[dataset-name]
|
|
1570
|
+
path: the path to download the dataset to
|
|
1571
|
+
force: force the download if the file already exists (default False)
|
|
1572
|
+
quiet: suppress verbose output (default is True)
|
|
1573
|
+
unzip: if True, unzip files upon download (default is False)
|
|
1574
|
+
licenses: a list of license names, e.g. ['CC0-1.0']
|
|
1575
|
+
"""
|
|
1576
|
+
if dataset is None:
|
|
1577
|
+
raise ValueError('A dataset must be specified')
|
|
1578
|
+
owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string(
|
|
1579
|
+
dataset)
|
|
1580
|
+
if path is None:
|
|
1581
|
+
effective_path = self.get_default_download_dir('datasets', owner_slug,
|
|
1582
|
+
dataset_slug)
|
|
1583
|
+
else:
|
|
1584
|
+
effective_path = path
|
|
1585
|
+
|
|
1586
|
+
self._print_dataset_url_and_license(owner_slug, dataset_slug,
|
|
1587
|
+
dataset_version_number, licenses)
|
|
1588
|
+
|
|
1589
|
+
with self.build_kaggle_client() as kaggle:
|
|
1590
|
+
request = ApiDownloadDatasetRequest()
|
|
1591
|
+
request.owner_slug = owner_slug
|
|
1592
|
+
request.dataset_slug = dataset_slug
|
|
1593
|
+
request.dataset_version_number = dataset_version_number
|
|
1594
|
+
response = kaggle.datasets.dataset_api_client.download_dataset(request)
|
|
1595
|
+
|
|
1596
|
+
outfile = os.path.join(effective_path, dataset_slug + '.zip')
|
|
1597
|
+
if force or self.download_needed(response, outfile, quiet):
|
|
1598
|
+
self.download_file(response, outfile, quiet, not force)
|
|
1599
|
+
downloaded = True
|
|
1600
|
+
else:
|
|
1601
|
+
downloaded = False
|
|
1602
|
+
|
|
1603
|
+
if downloaded:
|
|
1604
|
+
outfile = os.path.join(effective_path, dataset_slug + '.zip')
|
|
1605
|
+
if unzip:
|
|
1606
|
+
try:
|
|
1607
|
+
with zipfile.ZipFile(outfile) as z:
|
|
1608
|
+
z.extractall(effective_path)
|
|
1609
|
+
except zipfile.BadZipFile as e:
|
|
1610
|
+
raise ValueError(
|
|
1611
|
+
f"The file {outfile} is corrupted or not a valid zip file. "
|
|
1612
|
+
"Please report this issue at https://www.github.com/kaggle/kaggle-api"
|
|
1613
|
+
)
|
|
1614
|
+
except FileNotFoundError:
|
|
1615
|
+
raise FileNotFoundError(
|
|
1616
|
+
f"The file {outfile} was not found. "
|
|
1617
|
+
"Please report this issue at https://www.github.com/kaggle/kaggle-api"
|
|
1618
|
+
)
|
|
1619
|
+
except Exception as e:
|
|
1620
|
+
raise RuntimeError(
|
|
1621
|
+
f"An unexpected error occurred: {e}. "
|
|
1622
|
+
"Please report this issue at https://www.github.com/kaggle/kaggle-api"
|
|
1623
|
+
)
|
|
1624
|
+
|
|
1625
|
+
try:
|
|
1626
|
+
os.remove(outfile)
|
|
1627
|
+
except OSError as e:
|
|
1628
|
+
print('Could not delete zip file, got %s' % e)
|
|
1629
|
+
|
|
1630
|
+
def _print_dataset_url_and_license(self, owner_slug, dataset_slug,
|
|
1631
|
+
dataset_version_number, licenses):
|
|
1632
|
+
if dataset_version_number is None:
|
|
1633
|
+
print('Dataset URL: https://www.kaggle.com/datasets/%s/%s' %
|
|
1634
|
+
(owner_slug, dataset_slug))
|
|
1635
|
+
else:
|
|
1636
|
+
print('Dataset URL: https://www.kaggle.com/datasets/%s/%s/versions/%s' %
|
|
1637
|
+
(owner_slug, dataset_slug, dataset_version_number))
|
|
1638
|
+
|
|
1639
|
+
if len(licenses) > 0:
|
|
1640
|
+
print('License(s): %s' % (','.join(licenses)))
|
|
1641
|
+
|
|
1642
|
+
def dataset_download_cli(self,
|
|
1643
|
+
dataset,
|
|
1644
|
+
dataset_opt=None,
|
|
1645
|
+
file_name=None,
|
|
1646
|
+
path=None,
|
|
1647
|
+
unzip=False,
|
|
1648
|
+
force=False,
|
|
1649
|
+
quiet=False):
|
|
1650
|
+
""" Client wrapper for dataset_download_files and download dataset file,
|
|
1651
|
+
either for a specific file (when file_name is provided),
|
|
1652
|
+
or all files for a dataset (plural).
|
|
1653
|
+
|
|
1654
|
+
Parameters
|
|
1655
|
+
==========
|
|
1656
|
+
dataset: the string identified of the dataset
|
|
1657
|
+
should be in format [owner]/[dataset-name]
|
|
1658
|
+
dataset_opt: an alternative option to providing a dataset
|
|
1659
|
+
file_name: the dataset configuration file
|
|
1660
|
+
path: the path to download the dataset to
|
|
1661
|
+
force: force the download if the file already exists (default False)
|
|
1662
|
+
quiet: suppress verbose output (default is False)
|
|
1663
|
+
unzip: if True, unzip files upon download (default is False)
|
|
1664
|
+
"""
|
|
1665
|
+
dataset = dataset or dataset_opt
|
|
1666
|
+
|
|
1667
|
+
owner_slug, dataset_slug, _ = self.split_dataset_string(dataset)
|
|
1668
|
+
metadata = self.process_response(
|
|
1669
|
+
self.metadata_get_with_http_info(owner_slug, dataset_slug))
|
|
1670
|
+
|
|
1671
|
+
if 'info' in metadata and 'licenses' in metadata['info']:
|
|
1672
|
+
# license_objs format is like: [{ 'name': 'CC0-1.0' }]
|
|
1673
|
+
license_objs = metadata['info']['licenses']
|
|
1674
|
+
licenses = [
|
|
1675
|
+
license_obj['name']
|
|
1676
|
+
for license_obj in license_objs
|
|
1677
|
+
if 'name' in license_obj
|
|
1678
|
+
]
|
|
1679
|
+
else:
|
|
1680
|
+
licenses = [
|
|
1681
|
+
'Error retrieving license. Please visit the Dataset URL to view license information.'
|
|
1682
|
+
]
|
|
1683
|
+
|
|
1684
|
+
if file_name is None:
|
|
1685
|
+
self.dataset_download_files(
|
|
1686
|
+
dataset,
|
|
1687
|
+
path=path,
|
|
1688
|
+
unzip=unzip,
|
|
1689
|
+
force=force,
|
|
1690
|
+
quiet=quiet,
|
|
1691
|
+
licenses=licenses)
|
|
1692
|
+
else:
|
|
1693
|
+
self.dataset_download_file(
|
|
1694
|
+
dataset,
|
|
1695
|
+
file_name,
|
|
1696
|
+
path=path,
|
|
1697
|
+
force=force,
|
|
1698
|
+
quiet=quiet,
|
|
1699
|
+
licenses=licenses)
|
|
1700
|
+
|
|
1701
|
+
def _upload_blob(self, path, quiet, blob_type, upload_context):
|
|
1702
|
+
""" Upload a file.
|
|
1703
|
+
|
|
1704
|
+
Parameters
|
|
1705
|
+
==========
|
|
1706
|
+
path: the complete path to upload
|
|
1707
|
+
quiet: suppress verbose output (default is False)
|
|
1708
|
+
blob_type (ApiBlobType): To which entity the file/blob refers
|
|
1709
|
+
upload_context (ResumableUploadContext): Context for resumable uploads
|
|
1710
|
+
"""
|
|
1711
|
+
file_name = os.path.basename(path)
|
|
1712
|
+
content_length = os.path.getsize(path)
|
|
1713
|
+
last_modified_epoch_seconds = int(os.path.getmtime(path))
|
|
1714
|
+
|
|
1715
|
+
start_blob_upload_request = ApiStartBlobUploadRequest()
|
|
1716
|
+
start_blob_upload_request.type = blob_type
|
|
1717
|
+
start_blob_upload_request.name = file_name
|
|
1718
|
+
start_blob_upload_request.content_length = content_length
|
|
1719
|
+
start_blob_upload_request.last_modified_epoch_seconds = last_modified_epoch_seconds
|
|
1720
|
+
|
|
1721
|
+
file_upload = upload_context.new_resumable_file_upload(
|
|
1722
|
+
path, start_blob_upload_request)
|
|
1723
|
+
|
|
1724
|
+
for i in range(0, self.MAX_UPLOAD_RESUME_ATTEMPTS):
|
|
1725
|
+
if file_upload.upload_complete:
|
|
1726
|
+
return file_upload
|
|
1727
|
+
|
|
1728
|
+
if not file_upload.can_resume:
|
|
1729
|
+
# Initiate upload on Kaggle backend to get the url and token.
|
|
1730
|
+
with self.build_kaggle_client() as kaggle:
|
|
1731
|
+
method = kaggle.blobs.blob_api_client.start_blob_upload
|
|
1732
|
+
start_blob_upload_response = self.with_retry(method)(
|
|
1733
|
+
file_upload.start_blob_upload_request)
|
|
1734
|
+
file_upload.upload_initiated(start_blob_upload_response)
|
|
1735
|
+
|
|
1736
|
+
upload_result = self.upload_complete(
|
|
1737
|
+
path,
|
|
1738
|
+
file_upload.start_blob_upload_response.create_url,
|
|
1739
|
+
quiet,
|
|
1740
|
+
resume=file_upload.can_resume)
|
|
1741
|
+
if upload_result == ResumableUploadResult.INCOMPLETE:
|
|
1742
|
+
continue # Continue (i.e., retry/resume) only if the upload is incomplete.
|
|
1743
|
+
|
|
1744
|
+
if upload_result == ResumableUploadResult.COMPLETE:
|
|
1745
|
+
file_upload.upload_completed()
|
|
1746
|
+
break
|
|
1747
|
+
|
|
1748
|
+
return file_upload.get_token()
|
|
1749
|
+
|
|
1750
|
+
def dataset_create_version(self,
|
|
1751
|
+
folder,
|
|
1752
|
+
version_notes,
|
|
1753
|
+
quiet=False,
|
|
1754
|
+
convert_to_csv=True,
|
|
1755
|
+
delete_old_versions=False,
|
|
1756
|
+
dir_mode='skip'):
|
|
1757
|
+
""" Create a version of a dataset.
|
|
1758
|
+
|
|
1759
|
+
Parameters
|
|
1760
|
+
==========
|
|
1761
|
+
folder: the folder with the dataset configuration / data files
|
|
1762
|
+
version_notes: notes to add for the version
|
|
1763
|
+
quiet: suppress verbose output (default is False)
|
|
1764
|
+
convert_to_csv: on upload, if data should be converted to csv
|
|
1765
|
+
delete_old_versions: if True, do that (default False)
|
|
1766
|
+
dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
1767
|
+
"""
|
|
1768
|
+
if not os.path.isdir(folder):
|
|
1769
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
1770
|
+
|
|
1771
|
+
meta_file = self.get_dataset_metadata_file(folder)
|
|
1772
|
+
|
|
1773
|
+
# read json
|
|
1774
|
+
with open(meta_file) as f:
|
|
1775
|
+
meta_data = json.load(f)
|
|
1776
|
+
ref = self.get_or_default(meta_data, 'id', None)
|
|
1777
|
+
id_no = self.get_or_default(meta_data, 'id_no', None)
|
|
1778
|
+
if not ref and not id_no:
|
|
1779
|
+
raise ValueError('ID or slug must be specified in the metadata')
|
|
1780
|
+
elif ref and ref == self.config_values[
|
|
1781
|
+
self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE':
|
|
1782
|
+
raise ValueError(
|
|
1783
|
+
'Default slug detected, please change values before uploading')
|
|
1784
|
+
|
|
1785
|
+
subtitle = meta_data.get('subtitle')
|
|
1786
|
+
if subtitle and (len(subtitle) < 20 or len(subtitle) > 80):
|
|
1787
|
+
raise ValueError('Subtitle length must be between 20 and 80 characters')
|
|
1788
|
+
resources = meta_data.get('resources')
|
|
1789
|
+
if resources:
|
|
1790
|
+
self.validate_resources(folder, resources)
|
|
1791
|
+
|
|
1792
|
+
description = meta_data.get('description')
|
|
1793
|
+
keywords = self.get_or_default(meta_data, 'keywords', [])
|
|
1794
|
+
|
|
1795
|
+
body = ApiCreateDatasetVersionRequestBody()
|
|
1796
|
+
body.version_notes = version_notes
|
|
1797
|
+
body.subtitle = subtitle
|
|
1798
|
+
body.description = description
|
|
1799
|
+
body.files = []
|
|
1800
|
+
body.category_ids = keywords
|
|
1801
|
+
body.delete_old_versions = delete_old_versions
|
|
1802
|
+
|
|
1803
|
+
with self.build_kaggle_client() as kaggle:
|
|
1804
|
+
if id_no:
|
|
1805
|
+
request = ApiCreateDatasetVersionByIdRequest()
|
|
1806
|
+
request.id = id_no
|
|
1807
|
+
message = kaggle.datasets.dataset_api_client.create_dataset_version_by_id
|
|
1808
|
+
else:
|
|
1809
|
+
self.validate_dataset_string(ref)
|
|
1810
|
+
ref_list = ref.split('/')
|
|
1811
|
+
owner_slug = ref_list[0]
|
|
1812
|
+
dataset_slug = ref_list[1]
|
|
1813
|
+
request = ApiCreateDatasetVersionRequest()
|
|
1814
|
+
request.owner_slug = owner_slug
|
|
1815
|
+
request.dataset_slug = dataset_slug
|
|
1816
|
+
message = kaggle.datasets.dataset_api_client.create_dataset_version
|
|
1817
|
+
request.body = body
|
|
1818
|
+
with ResumableUploadContext() as upload_context:
|
|
1819
|
+
self.upload_files(body, resources, folder, ApiBlobType.DATASET,
|
|
1820
|
+
upload_context, quiet, dir_mode)
|
|
1821
|
+
request.body.files = [
|
|
1822
|
+
self._api_dataset_new_file(file) for file in request.body.files
|
|
1823
|
+
]
|
|
1824
|
+
response = self.with_retry(message)(request)
|
|
1825
|
+
return response
|
|
1826
|
+
|
|
1827
|
+
def _api_dataset_new_file(self, file):
|
|
1828
|
+
# TODO Eliminate the need for this conversion
|
|
1829
|
+
f = ApiDatasetNewFile()
|
|
1830
|
+
f.token = file.token
|
|
1831
|
+
return f
|
|
1832
|
+
|
|
1833
|
+
def dataset_create_version_cli(self,
|
|
1834
|
+
folder,
|
|
1835
|
+
version_notes,
|
|
1836
|
+
quiet=False,
|
|
1837
|
+
convert_to_csv=True,
|
|
1838
|
+
delete_old_versions=False,
|
|
1839
|
+
dir_mode='skip'):
|
|
1840
|
+
""" client wrapper for creating a version of a dataset
|
|
1841
|
+
Parameters
|
|
1842
|
+
==========
|
|
1843
|
+
folder: the folder with the dataset configuration / data files
|
|
1844
|
+
version_notes: notes to add for the version
|
|
1845
|
+
quiet: suppress verbose output (default is False)
|
|
1846
|
+
convert_to_csv: on upload, if data should be converted to csv
|
|
1847
|
+
delete_old_versions: if True, do that (default False)
|
|
1848
|
+
dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
1849
|
+
"""
|
|
1850
|
+
folder = folder or os.getcwd()
|
|
1851
|
+
result = self.dataset_create_version(
|
|
1852
|
+
folder,
|
|
1853
|
+
version_notes,
|
|
1854
|
+
quiet=quiet,
|
|
1855
|
+
convert_to_csv=convert_to_csv,
|
|
1856
|
+
delete_old_versions=delete_old_versions,
|
|
1857
|
+
dir_mode=dir_mode)
|
|
1858
|
+
|
|
1859
|
+
if result is None:
|
|
1860
|
+
print('Dataset version creation error: See previous output')
|
|
1861
|
+
elif result.invalidTags:
|
|
1862
|
+
print(('The following are not valid tags and could not be added to '
|
|
1863
|
+
'the dataset: ') + str(result.invalidTags))
|
|
1864
|
+
elif result.status.lower() == 'ok':
|
|
1865
|
+
print('Dataset version is being created. Please check progress at ' +
|
|
1866
|
+
result.url)
|
|
1867
|
+
else:
|
|
1868
|
+
print('Dataset version creation error: ' + result.error)
|
|
1869
|
+
|
|
1870
|
+
def dataset_initialize(self, folder):
|
|
1871
|
+
""" initialize a folder with a a dataset configuration (metadata) file
|
|
1872
|
+
|
|
1873
|
+
Parameters
|
|
1874
|
+
==========
|
|
1875
|
+
folder: the folder to initialize the metadata file in
|
|
1876
|
+
"""
|
|
1877
|
+
if not os.path.isdir(folder):
|
|
1878
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
1879
|
+
|
|
1880
|
+
ref = self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE'
|
|
1881
|
+
licenses = []
|
|
1882
|
+
default_license = {'name': 'CC0-1.0'}
|
|
1883
|
+
licenses.append(default_license)
|
|
1884
|
+
|
|
1885
|
+
meta_data = {'title': 'INSERT_TITLE_HERE', 'id': ref, 'licenses': licenses}
|
|
1886
|
+
meta_file = os.path.join(folder, self.DATASET_METADATA_FILE)
|
|
1887
|
+
with open(meta_file, 'w') as f:
|
|
1888
|
+
json.dump(meta_data, f, indent=2)
|
|
1889
|
+
|
|
1890
|
+
print('Data package template written to: ' + meta_file)
|
|
1891
|
+
return meta_file
|
|
1892
|
+
|
|
1893
|
+
def dataset_initialize_cli(self, folder=None):
|
|
1894
|
+
folder = folder or os.getcwd()
|
|
1895
|
+
self.dataset_initialize(folder)
|
|
1896
|
+
|
|
1897
|
+
def dataset_create_new(self,
|
|
1898
|
+
folder,
|
|
1899
|
+
public=False,
|
|
1900
|
+
quiet=False,
|
|
1901
|
+
convert_to_csv=True,
|
|
1902
|
+
dir_mode='skip'):
|
|
1903
|
+
""" Create a new dataset, meaning the same as creating a version but
|
|
1904
|
+
with extra metadata like license and user/owner.
|
|
1905
|
+
|
|
1906
|
+
Parameters
|
|
1907
|
+
==========
|
|
1908
|
+
folder: the folder to get the metadata file from
|
|
1909
|
+
public: should the dataset be public?
|
|
1910
|
+
quiet: suppress verbose output (default is False)
|
|
1911
|
+
convert_to_csv: if True, convert data to comma separated value
|
|
1912
|
+
dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
1913
|
+
"""
|
|
1914
|
+
if not os.path.isdir(folder):
|
|
1915
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
1916
|
+
|
|
1917
|
+
meta_file = self.get_dataset_metadata_file(folder)
|
|
1918
|
+
|
|
1919
|
+
# read json
|
|
1920
|
+
with open(meta_file) as f:
|
|
1921
|
+
meta_data = json.load(f)
|
|
1922
|
+
ref = self.get_or_fail(meta_data, 'id')
|
|
1923
|
+
title = self.get_or_fail(meta_data, 'title')
|
|
1924
|
+
licenses = self.get_or_fail(meta_data, 'licenses')
|
|
1925
|
+
ref_list = ref.split('/')
|
|
1926
|
+
owner_slug = ref_list[0]
|
|
1927
|
+
dataset_slug = ref_list[1]
|
|
1928
|
+
|
|
1929
|
+
# validations
|
|
1930
|
+
if ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE':
|
|
1931
|
+
raise ValueError(
|
|
1932
|
+
'Default slug detected, please change values before uploading')
|
|
1933
|
+
if title == 'INSERT_TITLE_HERE':
|
|
1934
|
+
raise ValueError(
|
|
1935
|
+
'Default title detected, please change values before uploading')
|
|
1936
|
+
if len(licenses) != 1:
|
|
1937
|
+
raise ValueError('Please specify exactly one license')
|
|
1938
|
+
if len(dataset_slug) < 6 or len(dataset_slug) > 50:
|
|
1939
|
+
raise ValueError('The dataset slug must be between 6 and 50 characters')
|
|
1940
|
+
if len(title) < 6 or len(title) > 50:
|
|
1941
|
+
raise ValueError('The dataset title must be between 6 and 50 characters')
|
|
1942
|
+
resources = meta_data.get('resources')
|
|
1943
|
+
if resources:
|
|
1944
|
+
self.validate_resources(folder, resources)
|
|
1945
|
+
|
|
1946
|
+
license_name = self.get_or_fail(licenses[0], 'name')
|
|
1947
|
+
description = meta_data.get('description')
|
|
1948
|
+
keywords = self.get_or_default(meta_data, 'keywords', [])
|
|
1949
|
+
|
|
1950
|
+
subtitle = meta_data.get('subtitle')
|
|
1951
|
+
if subtitle and (len(subtitle) < 20 or len(subtitle) > 80):
|
|
1952
|
+
raise ValueError('Subtitle length must be between 20 and 80 characters')
|
|
1953
|
+
|
|
1954
|
+
request = ApiCreateDatasetRequest()
|
|
1955
|
+
request.title = title
|
|
1956
|
+
request.slug = dataset_slug
|
|
1957
|
+
request.owner_slug = owner_slug
|
|
1958
|
+
request.license_name = license_name
|
|
1959
|
+
request.subtitle = subtitle
|
|
1960
|
+
request.description = description
|
|
1961
|
+
request.files = []
|
|
1962
|
+
request.is_private = not public
|
|
1963
|
+
# request.convert_to_csv=convert_to_csv
|
|
1964
|
+
request.category_ids = keywords
|
|
1965
|
+
|
|
1966
|
+
with ResumableUploadContext() as upload_context:
|
|
1967
|
+
self.upload_files(request, resources, folder, ApiBlobType.DATASET,
|
|
1968
|
+
upload_context, quiet, dir_mode)
|
|
1969
|
+
|
|
1970
|
+
with self.build_kaggle_client() as kaggle:
|
|
1971
|
+
retry_request = ApiCreateDatasetRequest()
|
|
1972
|
+
retry_request.title = title
|
|
1973
|
+
retry_request.slug = dataset_slug
|
|
1974
|
+
retry_request.owner_slug = owner_slug
|
|
1975
|
+
retry_request.license_name = license_name
|
|
1976
|
+
retry_request.subtitle = subtitle
|
|
1977
|
+
retry_request.description = description
|
|
1978
|
+
retry_request.files = [
|
|
1979
|
+
self._api_dataset_new_file(file) for file in request.files
|
|
1980
|
+
]
|
|
1981
|
+
retry_request.is_private = not public
|
|
1982
|
+
retry_request.category_ids = keywords
|
|
1983
|
+
response = self.with_retry(
|
|
1984
|
+
kaggle.datasets.dataset_api_client.create_dataset)(
|
|
1985
|
+
retry_request)
|
|
1986
|
+
return response
|
|
1987
|
+
|
|
1988
|
+
def dataset_create_new_cli(self,
|
|
1989
|
+
folder=None,
|
|
1990
|
+
public=False,
|
|
1991
|
+
quiet=False,
|
|
1992
|
+
convert_to_csv=True,
|
|
1993
|
+
dir_mode='skip'):
|
|
1994
|
+
""" client wrapper for creating a new dataset
|
|
1995
|
+
Parameters
|
|
1996
|
+
==========
|
|
1997
|
+
folder: the folder to get the metadata file from
|
|
1998
|
+
public: should the dataset be public?
|
|
1999
|
+
quiet: suppress verbose output (default is False)
|
|
2000
|
+
convert_to_csv: if True, convert data to comma separated value
|
|
2001
|
+
dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
2002
|
+
"""
|
|
2003
|
+
folder = folder or os.getcwd()
|
|
2004
|
+
result = self.dataset_create_new(folder, public, quiet, convert_to_csv,
|
|
2005
|
+
dir_mode)
|
|
2006
|
+
if result.invalidTags:
|
|
2007
|
+
print('The following are not valid tags and could not be added to '
|
|
2008
|
+
'the dataset: ' + str(result.invalidTags))
|
|
2009
|
+
if result.status.lower() == 'ok':
|
|
2010
|
+
if public:
|
|
2011
|
+
print('Your public Dataset is being created. Please check '
|
|
2012
|
+
'progress at ' + result.url)
|
|
2013
|
+
else:
|
|
2014
|
+
print('Your private Dataset is being created. Please check '
|
|
2015
|
+
'progress at ' + result.url)
|
|
2016
|
+
else:
|
|
2017
|
+
print('Dataset creation error: ' + result.error)
|
|
2018
|
+
|
|
2019
|
+
def download_file(self,
|
|
2020
|
+
response,
|
|
2021
|
+
outfile,
|
|
2022
|
+
http_client,
|
|
2023
|
+
quiet=True,
|
|
2024
|
+
resume=False,
|
|
2025
|
+
chunk_size=1048576):
|
|
2026
|
+
""" download a file to an output file based on a chunk size
|
|
2027
|
+
|
|
2028
|
+
Parameters
|
|
2029
|
+
==========
|
|
2030
|
+
response: the response to download
|
|
2031
|
+
outfile: the output file to download to
|
|
2032
|
+
http_client: the Kaggle http client to use
|
|
2033
|
+
quiet: suppress verbose output (default is True)
|
|
2034
|
+
chunk_size: the size of the chunk to stream
|
|
2035
|
+
resume: whether to resume an existing download
|
|
2036
|
+
"""
|
|
2037
|
+
|
|
2038
|
+
outpath = os.path.dirname(outfile)
|
|
2039
|
+
if not os.path.exists(outpath):
|
|
2040
|
+
os.makedirs(outpath)
|
|
2041
|
+
size = int(response.headers['Content-Length'])
|
|
2042
|
+
size_read = 0
|
|
2043
|
+
open_mode = 'wb'
|
|
2044
|
+
last_modified = response.headers.get('Last-Modified')
|
|
2045
|
+
if last_modified is None:
|
|
2046
|
+
remote_date = datetime.now()
|
|
2047
|
+
else:
|
|
2048
|
+
remote_date = datetime.strptime(response.headers['Last-Modified'],
|
|
2049
|
+
'%a, %d %b %Y %H:%M:%S %Z')
|
|
2050
|
+
remote_date_timestamp = time.mktime(remote_date.timetuple())
|
|
2051
|
+
|
|
2052
|
+
if not quiet:
|
|
2053
|
+
print('Downloading ' + os.path.basename(outfile) + ' to ' + outpath)
|
|
2054
|
+
|
|
2055
|
+
file_exists = os.path.isfile(outfile)
|
|
2056
|
+
resumable = 'Accept-Ranges' in response.headers and response.headers[
|
|
2057
|
+
'Accept-Ranges'] == 'bytes'
|
|
2058
|
+
|
|
2059
|
+
if resume and resumable and file_exists:
|
|
2060
|
+
size_read = os.path.getsize(outfile)
|
|
2061
|
+
open_mode = 'ab'
|
|
2062
|
+
|
|
2063
|
+
if not quiet:
|
|
2064
|
+
print("... resuming from %d bytes (%d bytes left) ..." % (
|
|
2065
|
+
size_read,
|
|
2066
|
+
size - size_read,
|
|
2067
|
+
))
|
|
2068
|
+
|
|
2069
|
+
request_history = response.history[0]
|
|
2070
|
+
response = http_client.call(
|
|
2071
|
+
request_history.request.method,
|
|
2072
|
+
request_history.headers['location'],
|
|
2073
|
+
headers={'Range': 'bytes=%d-' % (size_read,)},
|
|
2074
|
+
_preload_content=False)
|
|
2075
|
+
|
|
2076
|
+
with tqdm(
|
|
2077
|
+
total=size,
|
|
2078
|
+
initial=size_read,
|
|
2079
|
+
unit='B',
|
|
2080
|
+
unit_scale=True,
|
|
2081
|
+
unit_divisor=1024,
|
|
2082
|
+
disable=quiet) as pbar:
|
|
2083
|
+
with open(outfile, open_mode) as out:
|
|
2084
|
+
# TODO: Delete this test after all API methods are converted.
|
|
2085
|
+
if type(response).__name__ == 'HTTPResponse':
|
|
2086
|
+
while True:
|
|
2087
|
+
data = response.read(chunk_size)
|
|
2088
|
+
if not data:
|
|
2089
|
+
break
|
|
2090
|
+
out.write(data)
|
|
2091
|
+
os.utime(
|
|
2092
|
+
outfile,
|
|
2093
|
+
times=(remote_date_timestamp - 1, remote_date_timestamp - 1))
|
|
2094
|
+
size_read = min(size, size_read + chunk_size)
|
|
2095
|
+
pbar.update(len(data))
|
|
2096
|
+
else:
|
|
2097
|
+
for data in response.iter_content(chunk_size):
|
|
2098
|
+
if not data:
|
|
2099
|
+
break
|
|
2100
|
+
out.write(data)
|
|
2101
|
+
os.utime(
|
|
2102
|
+
outfile,
|
|
2103
|
+
times=(remote_date_timestamp - 1, remote_date_timestamp - 1))
|
|
2104
|
+
size_read = min(size, size_read + chunk_size)
|
|
2105
|
+
pbar.update(len(data))
|
|
2106
|
+
if not quiet:
|
|
2107
|
+
print('\n', end='')
|
|
2108
|
+
|
|
2109
|
+
os.utime(outfile, times=(remote_date_timestamp, remote_date_timestamp))
|
|
2110
|
+
|
|
2111
|
+
def kernels_list(self,
|
|
2112
|
+
page=1,
|
|
2113
|
+
page_size=20,
|
|
2114
|
+
dataset=None,
|
|
2115
|
+
competition=None,
|
|
2116
|
+
parent_kernel=None,
|
|
2117
|
+
search=None,
|
|
2118
|
+
mine=False,
|
|
2119
|
+
user=None,
|
|
2120
|
+
language=None,
|
|
2121
|
+
kernel_type=None,
|
|
2122
|
+
output_type=None,
|
|
2123
|
+
sort_by=None):
|
|
2124
|
+
""" List kernels based on a set of search criteria.
|
|
2125
|
+
|
|
2126
|
+
Parameters
|
|
2127
|
+
==========
|
|
2128
|
+
page: the page of results to return (default is 1)
|
|
2129
|
+
page_size: results per page (default is 20)
|
|
2130
|
+
dataset: if defined, filter to this dataset (default None)
|
|
2131
|
+
competition: if defined, filter to this competition (default None)
|
|
2132
|
+
parent_kernel: if defined, filter to those with specified parent
|
|
2133
|
+
search: a custom search string to pass to the list query
|
|
2134
|
+
mine: if true, group is specified as "my" to return personal kernels
|
|
2135
|
+
user: filter results to a specific user
|
|
2136
|
+
language: the programming language of the kernel
|
|
2137
|
+
kernel_type: the type of kernel, one of valid_list_kernel_types (str)
|
|
2138
|
+
output_type: the output type, one of valid_list_output_types (str)
|
|
2139
|
+
sort_by: if defined, sort results by this string (valid_list_sort_by)
|
|
2140
|
+
"""
|
|
2141
|
+
if int(page) <= 0:
|
|
2142
|
+
raise ValueError('Page number must be >= 1')
|
|
2143
|
+
|
|
2144
|
+
page_size = int(page_size)
|
|
2145
|
+
if page_size <= 0:
|
|
2146
|
+
raise ValueError('Page size must be >= 1')
|
|
2147
|
+
if page_size > 100:
|
|
2148
|
+
page_size = 100
|
|
2149
|
+
|
|
2150
|
+
if language and language not in self.valid_list_languages:
|
|
2151
|
+
raise ValueError('Invalid language specified. Valid options are ' +
|
|
2152
|
+
str(self.valid_list_languages))
|
|
2153
|
+
|
|
2154
|
+
if kernel_type and kernel_type not in self.valid_list_kernel_types:
|
|
2155
|
+
raise ValueError('Invalid kernel type specified. Valid options are ' +
|
|
2156
|
+
str(self.valid_list_kernel_types))
|
|
2157
|
+
|
|
2158
|
+
if output_type and output_type not in self.valid_list_output_types:
|
|
2159
|
+
raise ValueError('Invalid output type specified. Valid options are ' +
|
|
2160
|
+
str(self.valid_list_output_types))
|
|
2161
|
+
|
|
2162
|
+
if sort_by:
|
|
2163
|
+
if sort_by not in self.valid_list_sort_by:
|
|
2164
|
+
raise ValueError('Invalid sort by type specified. Valid options are ' +
|
|
2165
|
+
str(self.valid_list_sort_by))
|
|
2166
|
+
if sort_by == 'relevance' and search == '':
|
|
2167
|
+
raise ValueError('Cannot sort by relevance without a search term.')
|
|
2168
|
+
sort_by = self.lookup_enum(KernelsListSortType, sort_by)
|
|
2169
|
+
else:
|
|
2170
|
+
sort_by = KernelsListSortType.HOTNESS
|
|
2171
|
+
|
|
2172
|
+
self.validate_dataset_string(dataset)
|
|
2173
|
+
self.validate_kernel_string(parent_kernel)
|
|
2174
|
+
|
|
2175
|
+
group = 'everyone'
|
|
2176
|
+
if mine:
|
|
2177
|
+
group = 'profile'
|
|
2178
|
+
group = self.lookup_enum(KernelsListViewType, group)
|
|
2179
|
+
|
|
2180
|
+
with self.build_kaggle_client() as kaggle:
|
|
2181
|
+
request = ApiListKernelsRequest()
|
|
2182
|
+
request.page = page
|
|
2183
|
+
request.page_size = page_size
|
|
2184
|
+
request.group = group
|
|
2185
|
+
request.user = user or ''
|
|
2186
|
+
request.language = language or 'all'
|
|
2187
|
+
request.kernel_type = kernel_type or 'all'
|
|
2188
|
+
request.output_type = output_type or 'all'
|
|
2189
|
+
request.sort_by = sort_by
|
|
2190
|
+
request.dataset = dataset or ''
|
|
2191
|
+
request.competition = competition or ''
|
|
2192
|
+
request.parent_kernel = parent_kernel or ''
|
|
2193
|
+
request.search = search or ''
|
|
2194
|
+
return kaggle.kernels.kernels_api_client.list_kernels(request).kernels
|
|
2195
|
+
|
|
2196
|
+
kernels_list_result = self.process_response(
|
|
2197
|
+
self.kernels_list_with_http_info(
|
|
2198
|
+
page=page,
|
|
2199
|
+
page_size=page_size,
|
|
2200
|
+
group=group,
|
|
2201
|
+
user=user or '',
|
|
2202
|
+
language=language or 'all',
|
|
2203
|
+
kernel_type=kernel_type or 'all',
|
|
2204
|
+
output_type=output_type or 'all',
|
|
2205
|
+
sort_by=sort_by or 'hotness',
|
|
2206
|
+
dataset=dataset or '',
|
|
2207
|
+
competition=competition or '',
|
|
2208
|
+
parent_kernel=parent_kernel or '',
|
|
2209
|
+
search=search or ''))
|
|
2210
|
+
return [Kernel(k) for k in kernels_list_result]
|
|
2211
|
+
|
|
2212
|
+
def kernels_list_cli(self,
|
|
2213
|
+
mine=False,
|
|
2214
|
+
page=1,
|
|
2215
|
+
page_size=20,
|
|
2216
|
+
search=None,
|
|
2217
|
+
csv_display=False,
|
|
2218
|
+
parent=None,
|
|
2219
|
+
competition=None,
|
|
2220
|
+
dataset=None,
|
|
2221
|
+
user=None,
|
|
2222
|
+
language=None,
|
|
2223
|
+
kernel_type=None,
|
|
2224
|
+
output_type=None,
|
|
2225
|
+
sort_by=None):
|
|
2226
|
+
""" Client wrapper for kernels_list, see this function for arguments.
|
|
2227
|
+
Additional arguments are provided here.
|
|
2228
|
+
Parameters
|
|
2229
|
+
==========
|
|
2230
|
+
csv_display: if True, print comma separated values instead of table
|
|
2231
|
+
"""
|
|
2232
|
+
kernels = self.kernels_list(
|
|
2233
|
+
page=page,
|
|
2234
|
+
page_size=page_size,
|
|
2235
|
+
search=search,
|
|
2236
|
+
mine=mine,
|
|
2237
|
+
dataset=dataset,
|
|
2238
|
+
competition=competition,
|
|
2239
|
+
parent_kernel=parent,
|
|
2240
|
+
user=user,
|
|
2241
|
+
language=language,
|
|
2242
|
+
kernel_type=kernel_type,
|
|
2243
|
+
output_type=output_type,
|
|
2244
|
+
sort_by=sort_by)
|
|
2245
|
+
fields = ['ref', 'title', 'author', 'lastRunTime', 'totalVotes']
|
|
2246
|
+
if kernels:
|
|
2247
|
+
if csv_display:
|
|
2248
|
+
self.print_csv(kernels, fields)
|
|
2249
|
+
else:
|
|
2250
|
+
self.print_table(kernels, fields)
|
|
2251
|
+
else:
|
|
2252
|
+
print('Not found')
|
|
2253
|
+
|
|
2254
|
+
def kernels_list_files(self, kernel, page_token=None, page_size=20):
|
|
2255
|
+
""" list files for a kernel
|
|
2256
|
+
Parameters
|
|
2257
|
+
==========
|
|
2258
|
+
kernel: the string identifier of the kernel
|
|
2259
|
+
should be in format [owner]/[kernel-name]
|
|
2260
|
+
page_token: the page token for pagination
|
|
2261
|
+
page_size: the number of items per page
|
|
2262
|
+
"""
|
|
2263
|
+
if kernel is None:
|
|
2264
|
+
raise ValueError('A kernel must be specified')
|
|
2265
|
+
user_name, kernel_slug, kernel_version_number = self.split_dataset_string(
|
|
2266
|
+
kernel)
|
|
2267
|
+
|
|
2268
|
+
with self.build_kaggle_client() as kaggle:
|
|
2269
|
+
request = ApiListKernelFilesRequest()
|
|
2270
|
+
request.kernel_slug = kernel_slug
|
|
2271
|
+
request.user_name = user_name
|
|
2272
|
+
request.page_token = page_token
|
|
2273
|
+
request.page_size = page_size
|
|
2274
|
+
return kaggle.kernels.kernels_api_client.list_kernel_files(request)
|
|
2275
|
+
|
|
2276
|
+
def kernels_list_files_cli(self,
|
|
2277
|
+
kernel,
|
|
2278
|
+
kernel_opt=None,
|
|
2279
|
+
csv_display=False,
|
|
2280
|
+
page_token=None,
|
|
2281
|
+
page_size=20):
|
|
2282
|
+
""" A wrapper to kernel_list_files for the client.
|
|
2283
|
+
(list files for a kernel)
|
|
2284
|
+
Parameters
|
|
2285
|
+
==========
|
|
2286
|
+
kernel: the string identifier of the kernel
|
|
2287
|
+
should be in format [owner]/[kernel-name]
|
|
2288
|
+
kernel_opt: an alternative option to providing a kernel
|
|
2289
|
+
csv_display: if True, print comma separated values instead of table
|
|
2290
|
+
page_token: the page token for pagination
|
|
2291
|
+
page_size: the number of items per page
|
|
2292
|
+
"""
|
|
2293
|
+
kernel = kernel or kernel_opt
|
|
2294
|
+
result = self.kernels_list_files(kernel, page_token, page_size)
|
|
2295
|
+
|
|
2296
|
+
if result is None:
|
|
2297
|
+
print('No files found')
|
|
2298
|
+
return
|
|
2299
|
+
|
|
2300
|
+
if result.error_message:
|
|
2301
|
+
print(result.error_message)
|
|
2302
|
+
return
|
|
2303
|
+
|
|
2304
|
+
next_page_token = result.nextPageToken
|
|
2305
|
+
if next_page_token:
|
|
2306
|
+
print('Next Page Token = {}'.format(next_page_token))
|
|
2307
|
+
fields = ['name', 'size', 'creationDate']
|
|
2308
|
+
if csv_display:
|
|
2309
|
+
self.print_csv(result.files, fields)
|
|
2310
|
+
else:
|
|
2311
|
+
self.print_table(result.files, fields)
|
|
2312
|
+
|
|
2313
|
+
def kernels_initialize(self, folder):
|
|
2314
|
+
""" Create a new kernel in a specified folder from a template, including
|
|
2315
|
+
json metadata that grabs values from the configuration.
|
|
2316
|
+
Parameters
|
|
2317
|
+
==========
|
|
2318
|
+
folder: the path of the folder
|
|
2319
|
+
"""
|
|
2320
|
+
if not os.path.isdir(folder):
|
|
2321
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
2322
|
+
|
|
2323
|
+
resources = []
|
|
2324
|
+
resource = {'path': 'INSERT_SCRIPT_PATH_HERE'}
|
|
2325
|
+
resources.append(resource)
|
|
2326
|
+
|
|
2327
|
+
username = self.get_config_value(self.CONFIG_NAME_USER)
|
|
2328
|
+
meta_data = {
|
|
2329
|
+
'id':
|
|
2330
|
+
username + '/INSERT_KERNEL_SLUG_HERE',
|
|
2331
|
+
'title':
|
|
2332
|
+
'INSERT_TITLE_HERE',
|
|
2333
|
+
'code_file':
|
|
2334
|
+
'INSERT_CODE_FILE_PATH_HERE',
|
|
2335
|
+
'language':
|
|
2336
|
+
'Pick one of: {' +
|
|
2337
|
+
','.join(x for x in self.valid_push_language_types) + '}',
|
|
2338
|
+
'kernel_type':
|
|
2339
|
+
'Pick one of: {' +
|
|
2340
|
+
','.join(x for x in self.valid_push_kernel_types) + '}',
|
|
2341
|
+
'is_private':
|
|
2342
|
+
'true',
|
|
2343
|
+
'enable_gpu':
|
|
2344
|
+
'false',
|
|
2345
|
+
'enable_tpu':
|
|
2346
|
+
'false',
|
|
2347
|
+
'enable_internet':
|
|
2348
|
+
'true',
|
|
2349
|
+
'dataset_sources': [],
|
|
2350
|
+
'competition_sources': [],
|
|
2351
|
+
'kernel_sources': [],
|
|
2352
|
+
'model_sources': [],
|
|
2353
|
+
}
|
|
2354
|
+
meta_file = os.path.join(folder, self.KERNEL_METADATA_FILE)
|
|
2355
|
+
with open(meta_file, 'w') as f:
|
|
2356
|
+
json.dump(meta_data, f, indent=2)
|
|
2357
|
+
|
|
2358
|
+
return meta_file
|
|
2359
|
+
|
|
2360
|
+
def kernels_initialize_cli(self, folder=None):
|
|
2361
|
+
""" A client wrapper for kernels_initialize. It takes same arguments but
|
|
2362
|
+
sets default folder to be None. If None, defaults to present
|
|
2363
|
+
working directory.
|
|
2364
|
+
Parameters
|
|
2365
|
+
==========
|
|
2366
|
+
folder: the path of the folder (None defaults to ${PWD})
|
|
2367
|
+
"""
|
|
2368
|
+
folder = folder or os.getcwd()
|
|
2369
|
+
meta_file = self.kernels_initialize(folder)
|
|
2370
|
+
print('Kernel metadata template written to: ' + meta_file)
|
|
2371
|
+
|
|
2372
|
+
def kernels_push(self, folder, timeout):
|
|
2373
|
+
""" Read the metadata file and kernel files from a notebook, validate
|
|
2374
|
+
both, and use the Kernel API to push to Kaggle if all is valid.
|
|
2375
|
+
Parameters
|
|
2376
|
+
==========
|
|
2377
|
+
folder: the path of the folder
|
|
2378
|
+
"""
|
|
2379
|
+
if not os.path.isdir(folder):
|
|
2380
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
2381
|
+
|
|
2382
|
+
meta_file = os.path.join(folder, self.KERNEL_METADATA_FILE)
|
|
2383
|
+
if not os.path.isfile(meta_file):
|
|
2384
|
+
raise ValueError('Metadata file not found: ' + str(meta_file))
|
|
2385
|
+
|
|
2386
|
+
with open(meta_file) as f:
|
|
2387
|
+
meta_data = json.load(f)
|
|
2388
|
+
|
|
2389
|
+
title = self.get_or_default(meta_data, 'title', None)
|
|
2390
|
+
if title and len(title) < 5:
|
|
2391
|
+
raise ValueError('Title must be at least five characters')
|
|
2392
|
+
|
|
2393
|
+
code_path = self.get_or_default(meta_data, 'code_file', '')
|
|
2394
|
+
if not code_path:
|
|
2395
|
+
raise ValueError('A source file must be specified in the metadata')
|
|
2396
|
+
|
|
2397
|
+
code_file = os.path.join(folder, code_path)
|
|
2398
|
+
if not os.path.isfile(code_file):
|
|
2399
|
+
raise ValueError('Source file not found: ' + str(code_file))
|
|
2400
|
+
|
|
2401
|
+
slug = meta_data.get('id')
|
|
2402
|
+
id_no = meta_data.get('id_no')
|
|
2403
|
+
if not slug and not id_no:
|
|
2404
|
+
raise ValueError('ID or slug must be specified in the metadata')
|
|
2405
|
+
if slug:
|
|
2406
|
+
self.validate_kernel_string(slug)
|
|
2407
|
+
if '/' in slug:
|
|
2408
|
+
kernel_slug = slug.split('/')[1]
|
|
2409
|
+
else:
|
|
2410
|
+
kernel_slug = slug
|
|
2411
|
+
if title:
|
|
2412
|
+
as_slug = slugify(title)
|
|
2413
|
+
if kernel_slug.lower() != as_slug:
|
|
2414
|
+
print('Your kernel title does not resolve to the specified '
|
|
2415
|
+
'id. This may result in surprising behavior. We '
|
|
2416
|
+
'suggest making your title something that resolves to '
|
|
2417
|
+
'the specified id. See %s for more information on '
|
|
2418
|
+
'how slugs are determined.' %
|
|
2419
|
+
'https://en.wikipedia.org/wiki/Clean_URL#Slug')
|
|
2420
|
+
|
|
2421
|
+
language = self.get_or_default(meta_data, 'language', '')
|
|
2422
|
+
if language not in self.valid_push_language_types:
|
|
2423
|
+
raise ValueError(
|
|
2424
|
+
'A valid language must be specified in the metadata. Valid '
|
|
2425
|
+
'options are ' + str(self.valid_push_language_types))
|
|
2426
|
+
|
|
2427
|
+
kernel_type = self.get_or_default(meta_data, 'kernel_type', '')
|
|
2428
|
+
if kernel_type not in self.valid_push_kernel_types:
|
|
2429
|
+
raise ValueError(
|
|
2430
|
+
'A valid kernel type must be specified in the metadata. Valid '
|
|
2431
|
+
'options are ' + str(self.valid_push_kernel_types))
|
|
2432
|
+
|
|
2433
|
+
if kernel_type == 'notebook' and language == 'rmarkdown':
|
|
2434
|
+
language = 'r'
|
|
2435
|
+
|
|
2436
|
+
dataset_sources = self.get_or_default(meta_data, 'dataset_sources', [])
|
|
2437
|
+
for source in dataset_sources:
|
|
2438
|
+
self.validate_dataset_string(source)
|
|
2439
|
+
|
|
2440
|
+
kernel_sources = self.get_or_default(meta_data, 'kernel_sources', [])
|
|
2441
|
+
for source in kernel_sources:
|
|
2442
|
+
self.validate_kernel_string(source)
|
|
2443
|
+
|
|
2444
|
+
model_sources = self.get_or_default(meta_data, 'model_sources', [])
|
|
2445
|
+
for source in model_sources:
|
|
2446
|
+
self.validate_model_string(source)
|
|
2447
|
+
|
|
2448
|
+
docker_pinning_type = self.get_or_default(meta_data,
|
|
2449
|
+
'docker_image_pinning_type', None)
|
|
2450
|
+
if (docker_pinning_type is not None and
|
|
2451
|
+
docker_pinning_type not in self.valid_push_pinning_types):
|
|
2452
|
+
raise ValueError('If specified, the docker_image_pinning_type must be '
|
|
2453
|
+
'one of ' + str(self.valid_push_pinning_types))
|
|
2454
|
+
|
|
2455
|
+
with open(code_file) as f:
|
|
2456
|
+
script_body = f.read()
|
|
2457
|
+
|
|
2458
|
+
if kernel_type == 'notebook':
|
|
2459
|
+
json_body = json.loads(script_body)
|
|
2460
|
+
if 'cells' in json_body:
|
|
2461
|
+
for cell in json_body['cells']:
|
|
2462
|
+
if 'outputs' in cell and cell['cell_type'] == 'code':
|
|
2463
|
+
cell['outputs'] = []
|
|
2464
|
+
# The spec allows a list of strings,
|
|
2465
|
+
# but the server expects just one
|
|
2466
|
+
if 'source' in cell and isinstance(cell['source'], list):
|
|
2467
|
+
cell['source'] = ''.join(cell['source'])
|
|
2468
|
+
script_body = json.dumps(json_body).replace("'", "\\'")
|
|
2469
|
+
|
|
2470
|
+
with self.build_kaggle_client() as kaggle:
|
|
2471
|
+
request = ApiSaveKernelRequest()
|
|
2472
|
+
request.id = id_no
|
|
2473
|
+
request.slug = slug
|
|
2474
|
+
request.new_title = self.get_or_default(meta_data, 'title', None)
|
|
2475
|
+
request.text = script_body
|
|
2476
|
+
request.language = language
|
|
2477
|
+
request.kernel_type = kernel_type
|
|
2478
|
+
request.is_private = self.get_bool(meta_data, 'is_private', True)
|
|
2479
|
+
request.enable_gpu = self.get_bool(meta_data, 'enable_gpu', False)
|
|
2480
|
+
request.enable_tpu = self.get_bool(meta_data, 'enable_tpu', False)
|
|
2481
|
+
request.enable_internet = self.get_bool(meta_data, 'enable_internet',
|
|
2482
|
+
True)
|
|
2483
|
+
request.dataset_data_sources = dataset_sources
|
|
2484
|
+
request.competition_data_sources = self.get_or_default(
|
|
2485
|
+
meta_data, 'competition_sources', [])
|
|
2486
|
+
request.kernel_data_sources = kernel_sources
|
|
2487
|
+
request.model_data_sources = model_sources
|
|
2488
|
+
request.category_ids = self.get_or_default(meta_data, 'keywords', [])
|
|
2489
|
+
request.docker_image_pinning_type = docker_pinning_type
|
|
2490
|
+
if timeout:
|
|
2491
|
+
request.session_timeout_seconds = int(timeout)
|
|
2492
|
+
return kaggle.kernels.kernels_api_client.save_kernel(request)
|
|
2493
|
+
|
|
2494
|
+
def kernels_push_cli(self, folder, timeout):
|
|
2495
|
+
""" Client wrapper for kernels_push.
|
|
2496
|
+
Parameters
|
|
2497
|
+
==========
|
|
2498
|
+
folder: the path of the folder
|
|
2499
|
+
"""
|
|
2500
|
+
folder = folder or os.getcwd()
|
|
2501
|
+
result = self.kernels_push(folder, timeout)
|
|
2502
|
+
|
|
2503
|
+
if result is None:
|
|
2504
|
+
print('Kernel push error: see previous output')
|
|
2505
|
+
elif not result.error:
|
|
2506
|
+
if result.invalidTags:
|
|
2507
|
+
print('The following are not valid tags and could not be added '
|
|
2508
|
+
'to the kernel: ' + str(result.invalidTags))
|
|
2509
|
+
if result.invalidDatasetSources:
|
|
2510
|
+
print('The following are not valid dataset sources and could not '
|
|
2511
|
+
'be added to the kernel: ' + str(result.invalidDatasetSources))
|
|
2512
|
+
if result.invalidCompetitionSources:
|
|
2513
|
+
print('The following are not valid competition sources and could '
|
|
2514
|
+
'not be added to the kernel: ' +
|
|
2515
|
+
str(result.invalidCompetitionSources))
|
|
2516
|
+
if result.invalidKernelSources:
|
|
2517
|
+
print('The following are not valid kernel sources and could not '
|
|
2518
|
+
'be added to the kernel: ' + str(result.invalidKernelSources))
|
|
2519
|
+
|
|
2520
|
+
if result.versionNumber:
|
|
2521
|
+
print('Kernel version %s successfully pushed. Please check '
|
|
2522
|
+
'progress at %s' % (result.versionNumber, result.url))
|
|
2523
|
+
else:
|
|
2524
|
+
# Shouldn't happen but didn't test exhaustively
|
|
2525
|
+
print('Kernel version successfully pushed. Please check '
|
|
2526
|
+
'progress at %s' % result.url)
|
|
2527
|
+
else:
|
|
2528
|
+
print('Kernel push error: ' + result.error)
|
|
2529
|
+
|
|
2530
|
+
def kernels_pull(self, kernel, path, metadata=False, quiet=True):
|
|
2531
|
+
""" Pull a kernel, including a metadata file (if metadata is True)
|
|
2532
|
+
and associated files to a specified path.
|
|
2533
|
+
Parameters
|
|
2534
|
+
==========
|
|
2535
|
+
kernel: the kernel to pull
|
|
2536
|
+
path: the path to pull files to on the filesystem
|
|
2537
|
+
metadata: if True, also pull metadata
|
|
2538
|
+
quiet: suppress verbosity (default is True)
|
|
2539
|
+
"""
|
|
2540
|
+
existing_metadata = None
|
|
2541
|
+
if kernel is None:
|
|
2542
|
+
if path is None:
|
|
2543
|
+
existing_metadata_path = os.path.join(os.getcwd(),
|
|
2544
|
+
self.KERNEL_METADATA_FILE)
|
|
2545
|
+
else:
|
|
2546
|
+
existing_metadata_path = os.path.join(path, self.KERNEL_METADATA_FILE)
|
|
2547
|
+
if os.path.exists(existing_metadata_path):
|
|
2548
|
+
with open(existing_metadata_path) as f:
|
|
2549
|
+
existing_metadata = json.load(f)
|
|
2550
|
+
kernel = existing_metadata['id']
|
|
2551
|
+
if 'INSERT_KERNEL_SLUG_HERE' in kernel:
|
|
2552
|
+
raise ValueError('A kernel must be specified')
|
|
2553
|
+
else:
|
|
2554
|
+
print('Using kernel ' + kernel)
|
|
2555
|
+
|
|
2556
|
+
if '/' in kernel:
|
|
2557
|
+
self.validate_kernel_string(kernel)
|
|
2558
|
+
kernel_url_list = kernel.split('/')
|
|
2559
|
+
owner_slug = kernel_url_list[0]
|
|
2560
|
+
kernel_slug = kernel_url_list[1]
|
|
2561
|
+
else:
|
|
2562
|
+
owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
|
|
2563
|
+
kernel_slug = kernel
|
|
2564
|
+
|
|
2565
|
+
if path is None:
|
|
2566
|
+
effective_path = self.get_default_download_dir('kernels', owner_slug,
|
|
2567
|
+
kernel_slug)
|
|
2568
|
+
else:
|
|
2569
|
+
effective_path = path
|
|
2570
|
+
|
|
2571
|
+
if not os.path.exists(effective_path):
|
|
2572
|
+
os.makedirs(effective_path)
|
|
2573
|
+
|
|
2574
|
+
with self.build_kaggle_client() as kaggle:
|
|
2575
|
+
request = ApiGetKernelRequest()
|
|
2576
|
+
request.user_name = owner_slug
|
|
2577
|
+
request.kernel_slug = kernel_slug
|
|
2578
|
+
response = kaggle.kernels.kernels_api_client.get_kernel(request)
|
|
2579
|
+
|
|
2580
|
+
blob = response.blob
|
|
2581
|
+
|
|
2582
|
+
if os.path.isfile(effective_path):
|
|
2583
|
+
effective_dir = os.path.dirname(effective_path)
|
|
2584
|
+
else:
|
|
2585
|
+
effective_dir = effective_path
|
|
2586
|
+
metadata_path = os.path.join(effective_dir, self.KERNEL_METADATA_FILE)
|
|
2587
|
+
|
|
2588
|
+
if not os.path.isfile(effective_path):
|
|
2589
|
+
language = blob.language.lower()
|
|
2590
|
+
kernel_type = blob.kernel_type.lower()
|
|
2591
|
+
|
|
2592
|
+
file_name = None
|
|
2593
|
+
if existing_metadata:
|
|
2594
|
+
file_name = existing_metadata['code_file']
|
|
2595
|
+
elif os.path.isfile(metadata_path):
|
|
2596
|
+
with open(metadata_path) as f:
|
|
2597
|
+
file_name = json.load(f)['code_file']
|
|
2598
|
+
|
|
2599
|
+
if not file_name or file_name == "INSERT_CODE_FILE_PATH_HERE":
|
|
2600
|
+
extension = None
|
|
2601
|
+
if kernel_type == 'script':
|
|
2602
|
+
if language == 'python':
|
|
2603
|
+
extension = '.py'
|
|
2604
|
+
elif language == 'r':
|
|
2605
|
+
extension = '.R'
|
|
2606
|
+
elif language == 'rmarkdown':
|
|
2607
|
+
extension = '.Rmd'
|
|
2608
|
+
elif language == 'sqlite':
|
|
2609
|
+
extension = '.sql'
|
|
2610
|
+
elif language == 'julia':
|
|
2611
|
+
extension = '.jl'
|
|
2612
|
+
elif kernel_type == 'notebook':
|
|
2613
|
+
if language == 'python':
|
|
2614
|
+
extension = '.ipynb'
|
|
2615
|
+
elif language == 'r':
|
|
2616
|
+
extension = '.irnb'
|
|
2617
|
+
elif language == 'julia':
|
|
2618
|
+
extension = '.ijlnb'
|
|
2619
|
+
file_name = blob.slug + extension
|
|
2620
|
+
|
|
2621
|
+
if file_name is None:
|
|
2622
|
+
print('Unknown language %s + kernel type %s - please report this '
|
|
2623
|
+
'on the kaggle-api github issues' % (language, kernel_type))
|
|
2624
|
+
print('Saving as a python file, even though this may not be the '
|
|
2625
|
+
'correct language')
|
|
2626
|
+
file_name = 'script.py'
|
|
2627
|
+
script_path = os.path.join(effective_path, file_name)
|
|
2628
|
+
else:
|
|
2629
|
+
script_path = effective_path
|
|
2630
|
+
file_name = os.path.basename(effective_path)
|
|
2631
|
+
|
|
2632
|
+
with open(script_path, 'w', encoding="utf-8") as f:
|
|
2633
|
+
f.write(blob.source)
|
|
2634
|
+
|
|
2635
|
+
if metadata:
|
|
2636
|
+
data = {}
|
|
2637
|
+
server_metadata = response.metadata
|
|
2638
|
+
data['id'] = server_metadata.ref
|
|
2639
|
+
data['id_no'] = server_metadata.id
|
|
2640
|
+
data['title'] = server_metadata.title
|
|
2641
|
+
data['code_file'] = file_name
|
|
2642
|
+
data['language'] = server_metadata.language
|
|
2643
|
+
data['kernel_type'] = server_metadata.kernel_type
|
|
2644
|
+
data['is_private'] = server_metadata.is_private
|
|
2645
|
+
data['enable_gpu'] = server_metadata.enable_gpu
|
|
2646
|
+
data['enable_tpu'] = server_metadata.enable_tpu
|
|
2647
|
+
data['enable_internet'] = server_metadata.enable_internet
|
|
2648
|
+
data['keywords'] = server_metadata.category_ids
|
|
2649
|
+
data['dataset_sources'] = server_metadata.dataset_data_sources
|
|
2650
|
+
data['kernel_sources'] = server_metadata.kernel_data_sources
|
|
2651
|
+
data['competition_sources'] = server_metadata.competition_data_sources
|
|
2652
|
+
data['model_sources'] = server_metadata.model_data_sources
|
|
2653
|
+
with open(metadata_path, 'w') as f:
|
|
2654
|
+
json.dump(data, f, indent=2)
|
|
2655
|
+
|
|
2656
|
+
return effective_dir
|
|
2657
|
+
else:
|
|
2658
|
+
return script_path
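For illustration only (not part of the package diff): a minimal usage sketch of kernels_pull, assuming the usual KaggleApi entry point and credentials configured; the kernel reference and path are placeholders.
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()
# Pull the source file plus kernel-metadata.json and print the destination directory.
dest = api.kernels_pull('someuser/some-kernel', path='/tmp/some-kernel', metadata=True)
print(dest)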
|
|
2659
|
+
|
|
2660
|
+
def kernels_pull_cli(self,
|
|
2661
|
+
kernel,
|
|
2662
|
+
kernel_opt=None,
|
|
2663
|
+
path=None,
|
|
2664
|
+
metadata=False):
|
|
2665
|
+
""" Client wrapper for kernels_pull.
|
|
2666
|
+
"""
|
|
2667
|
+
kernel = kernel or kernel_opt
|
|
2668
|
+
effective_path = self.kernels_pull(
|
|
2669
|
+
kernel, path=path, metadata=metadata, quiet=False)
|
|
2670
|
+
if metadata:
|
|
2671
|
+
print('Source code and metadata downloaded to ' + effective_path)
|
|
2672
|
+
else:
|
|
2673
|
+
print('Source code downloaded to ' + effective_path)
|
|
2674
|
+
|
|
2675
|
+
def kernels_output(self, kernel, path, force=False, quiet=True):
|
|
2676
|
+
""" Retrieve the output for a specified kernel.
|
|
2677
|
+
Parameters
|
|
2678
|
+
==========
|
|
2679
|
+
kernel: the kernel to output
|
|
2680
|
+
path: the path to pull files to on the filesystem
|
|
2681
|
+
force: if output already exists, force overwrite (default False)
|
|
2682
|
+
quiet: suppress verbosity (default is True)
|
|
2683
|
+
"""
|
|
2684
|
+
if kernel is None:
|
|
2685
|
+
raise ValueError('A kernel must be specified')
|
|
2686
|
+
if '/' in kernel:
|
|
2687
|
+
self.validate_kernel_string(kernel)
|
|
2688
|
+
kernel_url_list = kernel.split('/')
|
|
2689
|
+
owner_slug = kernel_url_list[0]
|
|
2690
|
+
kernel_slug = kernel_url_list[1]
|
|
2691
|
+
else:
|
|
2692
|
+
owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
|
|
2693
|
+
kernel_slug = kernel
|
|
2694
|
+
|
|
2695
|
+
if path is None:
|
|
2696
|
+
target_dir = self.get_default_download_dir('kernels', owner_slug,
|
|
2697
|
+
kernel_slug, 'output')
|
|
2698
|
+
else:
|
|
2699
|
+
target_dir = path
|
|
2700
|
+
|
|
2701
|
+
if not os.path.exists(target_dir):
|
|
2702
|
+
os.makedirs(target_dir)
|
|
2703
|
+
|
|
2704
|
+
if not os.path.isdir(target_dir):
|
|
2705
|
+
raise ValueError('You must specify a directory for the kernels output')
|
|
2706
|
+
|
|
2707
|
+
token = None
|
|
2708
|
+
with self.build_kaggle_client() as kaggle:
|
|
2709
|
+
request = ApiListKernelSessionOutputRequest()
|
|
2710
|
+
request.user_name = owner_slug
|
|
2711
|
+
request.kernel_slug = kernel_slug
|
|
2712
|
+
response = kaggle.kernels.kernels_api_client.list_kernel_session_output(
|
|
2713
|
+
request)
|
|
2714
|
+
token = response.next_page_token
|
|
2715
|
+
|
|
2716
|
+
outfiles = []
|
|
2717
|
+
for item in response.files:
|
|
2718
|
+
outfile = os.path.join(target_dir, item['fileName'])
|
|
2719
|
+
outfiles.append(outfile)
|
|
2720
|
+
download_response = requests.get(item['url'], stream=True)
|
|
2721
|
+
if force or self.download_needed(download_response, outfile, quiet):
|
|
2722
|
+
os.makedirs(os.path.split(outfile)[0], exist_ok=True)
|
|
2723
|
+
with open(outfile, 'wb') as out:
|
|
2724
|
+
out.write(download_response.content)
|
|
2725
|
+
if not quiet:
|
|
2726
|
+
print('Output file downloaded to %s' % outfile)
|
|
2727
|
+
|
|
2728
|
+
log = response.log
|
|
2729
|
+
if log:
|
|
2730
|
+
outfile = os.path.join(target_dir, kernel_slug + '.log')
|
|
2731
|
+
outfiles.append(outfile)
|
|
2732
|
+
with open(outfile, 'w') as out:
|
|
2733
|
+
out.write(log)
|
|
2734
|
+
if not quiet:
|
|
2735
|
+
print('Kernel log downloaded to %s ' % outfile)
|
|
2736
|
+
|
|
2737
|
+
return outfiles, token # Breaking change, we need to get the token to the UI
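For illustration only (not part of the package diff): a minimal sketch of the new two-value return of kernels_output, assuming the usual KaggleApi entry point; the kernel reference and path are placeholders.
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()
# kernels_output now returns (files, next_page_token).
files, next_token = api.kernels_output('someuser/some-kernel', '/tmp/out', force=True, quiet=False)
for path in files:
    print('downloaded', path)
if next_token:
    print('more output available; next page token:', next_token)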
|
|
2738
|
+
|
|
2739
|
+
def kernels_output_cli(self,
|
|
2740
|
+
kernel,
|
|
2741
|
+
kernel_opt=None,
|
|
2742
|
+
path=None,
|
|
2743
|
+
force=False,
|
|
2744
|
+
quiet=False):
|
|
2745
|
+
""" Client wrapper for kernels_output, with same arguments. Extra
|
|
2746
|
+
arguments are described below, and see kernels_output for others.
|
|
2747
|
+
Parameters
|
|
2748
|
+
==========
|
|
2749
|
+
kernel_opt: alternative way to specify the kernel, used when kernel is not given
|
|
2750
|
+
"""
|
|
2751
|
+
kernel = kernel or kernel_opt
|
|
2752
|
+
(_, token) = self.kernels_output(kernel, path, force, quiet)
|
|
2753
|
+
if token:
|
|
2754
|
+
print(f"Next page token: {token}")
|
|
2755
|
+
|
|
2756
|
+
def kernels_status(self, kernel):
|
|
2757
|
+
""" Call to the api to get the status of a kernel.
|
|
2758
|
+
Parameters
|
|
2759
|
+
==========
|
|
2760
|
+
kernel: the kernel to get the status for
|
|
2761
|
+
"""
|
|
2762
|
+
if kernel is None:
|
|
2763
|
+
raise ValueError('A kernel must be specified')
|
|
2764
|
+
if '/' in kernel:
|
|
2765
|
+
self.validate_kernel_string(kernel)
|
|
2766
|
+
kernel_url_list = kernel.split('/')
|
|
2767
|
+
owner_slug = kernel_url_list[0]
|
|
2768
|
+
kernel_slug = kernel_url_list[1]
|
|
2769
|
+
else:
|
|
2770
|
+
owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
|
|
2771
|
+
kernel_slug = kernel
|
|
2772
|
+
with self.build_kaggle_client() as kaggle:
|
|
2773
|
+
request = ApiGetKernelSessionStatusRequest()
|
|
2774
|
+
request.user_name = owner_slug
|
|
2775
|
+
request.kernel_slug = kernel_slug
|
|
2776
|
+
return kaggle.kernels.kernels_api_client.get_kernel_session_status(
|
|
2777
|
+
request)
|
|
2778
|
+
|
|
2779
|
+
def kernels_status_cli(self, kernel, kernel_opt=None):
|
|
2780
|
+
""" Client wrapper for kernel_status.
|
|
2781
|
+
Parameters
|
|
2782
|
+
==========
|
|
2783
|
+
kernel_opt: alternative way to specify the kernel, used when kernel is not given
|
|
2784
|
+
"""
|
|
2785
|
+
kernel = kernel or kernel_opt
|
|
2786
|
+
response = self.kernels_status(kernel)
|
|
2787
|
+
status = response.status
|
|
2788
|
+
message = response.failure_message
|
|
2789
|
+
if message:
|
|
2790
|
+
print('%s has status "%s"' % (kernel, status))
|
|
2791
|
+
print('Failure message: "%s"' % message)
|
|
2792
|
+
else:
|
|
2793
|
+
print('%s has status "%s"' % (kernel, status))
|
|
2794
|
+
|
|
2795
|
+
def model_get(self, model):
|
|
2796
|
+
""" Get a model.
|
|
2797
|
+
Parameters
|
|
2798
|
+
==========
|
|
2799
|
+
model: the string identifier of the model
|
|
2800
|
+
should be in format [owner]/[model-name]
|
|
2801
|
+
"""
|
|
2802
|
+
owner_slug, model_slug = self.split_model_string(model)
|
|
2803
|
+
|
|
2804
|
+
with self.build_kaggle_client() as kaggle:
|
|
2805
|
+
request = ApiGetModelRequest()
|
|
2806
|
+
request.owner_slug = owner_slug
|
|
2807
|
+
request.model_slug = model_slug
|
|
2808
|
+
return kaggle.models.model_api_client.get_model(request)
|
|
2809
|
+
|
|
2810
|
+
def model_get_cli(self, model, folder=None):
|
|
2811
|
+
""" Clent wrapper for model_get, with additional
|
|
2812
|
+
folder option to save the model metadata file.
|
|
2813
|
+
Parameters
|
|
2814
|
+
==========
|
|
2815
|
+
model: the string identifier of the model
|
|
2816
|
+
should be in format [owner]/[model-name]
|
|
2817
|
+
folder: the folder to download the model metadata file to
|
|
2818
|
+
"""
|
|
2819
|
+
model = self.model_get(model)
|
|
2820
|
+
if folder is None:
|
|
2821
|
+
self.print_obj(model)
|
|
2822
|
+
else:
|
|
2823
|
+
meta_file = os.path.join(folder, self.MODEL_METADATA_FILE)
|
|
2824
|
+
|
|
2825
|
+
data = {}
|
|
2826
|
+
data['id'] = model.id
|
|
2827
|
+
model_ref_split = model.ref.split('/')
|
|
2828
|
+
data['ownerSlug'] = model_ref_split[0]
|
|
2829
|
+
data['slug'] = model_ref_split[1]
|
|
2830
|
+
data['title'] = model.title
|
|
2831
|
+
data['subtitle'] = model.subtitle
|
|
2832
|
+
data['isPrivate'] = model.isPrivate # TODO Test to ensure True default
|
|
2833
|
+
data['description'] = model.description
|
|
2834
|
+
data['publishTime'] = model.publishTime
|
|
2835
|
+
|
|
2836
|
+
with open(meta_file, 'w') as f:
|
|
2837
|
+
json.dump(data, f, indent=2)
|
|
2838
|
+
print('Metadata file written to {}'.format(meta_file))
|
|
2839
|
+
|
|
2840
|
+
def model_list(self,
|
|
2841
|
+
sort_by=None,
|
|
2842
|
+
search=None,
|
|
2843
|
+
owner=None,
|
|
2844
|
+
page_size=20,
|
|
2845
|
+
page_token=None):
|
|
2846
|
+
""" Return a list of models.
|
|
2847
|
+
|
|
2848
|
+
Parameters
|
|
2849
|
+
==========
|
|
2850
|
+
sort_by: how to sort the result, see valid_model_sort_bys for options
|
|
2851
|
+
search: a search term to use (default is empty string)
|
|
2852
|
+
owner: username or organization slug to filter the search to
|
|
2853
|
+
page_size: the page size to return (default is 20)
|
|
2854
|
+
page_token: the page token for pagination
|
|
2855
|
+
"""
|
|
2856
|
+
if sort_by:
|
|
2857
|
+
if sort_by not in self.valid_model_sort_bys:
|
|
2858
|
+
raise ValueError('Invalid sort by specified. Valid options are ' +
|
|
2859
|
+
str(self.valid_model_sort_bys))
|
|
2860
|
+
sort_by = self.lookup_enum(ListModelsOrderBy, sort_by)
|
|
2861
|
+
|
|
2862
|
+
if int(page_size) <= 0:
|
|
2863
|
+
raise ValueError('Page size must be >= 1')
|
|
2864
|
+
|
|
2865
|
+
with self.build_kaggle_client() as kaggle:
|
|
2866
|
+
request = ApiListModelsRequest()
|
|
2867
|
+
request.sort_by = sort_by or ListModelsOrderBy.LIST_MODELS_ORDER_BY_HOTNESS
|
|
2868
|
+
request.search = search or ''
|
|
2869
|
+
request.owner = owner or ''
|
|
2870
|
+
request.page_size = page_size
|
|
2871
|
+
request.page_token = page_token
|
|
2872
|
+
response = kaggle.models.model_api_client.list_models(request)
|
|
2873
|
+
if response.next_page_token:
|
|
2874
|
+
print('Next Page Token = {}'.format(response.next_page_token))
|
|
2875
|
+
return response.models
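For illustration only (not part of the package diff): a minimal usage sketch of model_list, assuming the usual KaggleApi entry point; 'hotness' is shown only as an example of a sort value and must be one of valid_model_sort_bys.
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()
models = api.model_list(sort_by='hotness', search='bert', page_size=5)
for m in models:
    # ref and title are among the fields printed by model_list_cli.
    print(m.ref, m.title)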
|
|
2876
|
+
|
|
2877
|
+
def model_list_cli(self,
|
|
2878
|
+
sort_by=None,
|
|
2879
|
+
search=None,
|
|
2880
|
+
owner=None,
|
|
2881
|
+
page_size=20,
|
|
2882
|
+
page_token=None,
|
|
2883
|
+
csv_display=False):
|
|
2884
|
+
""" Client wrapper for model_list. Additional parameters
|
|
2885
|
+
are described here, see model_list for others.
|
|
2886
|
+
|
|
2887
|
+
Parameters
|
|
2888
|
+
==========
|
|
2889
|
+
sort_by: how to sort the result, see valid_model_sort_bys for options
|
|
2890
|
+
search: a search term to use (default is empty string)
|
|
2891
|
+
owner: username or organization slug to filter the search to
|
|
2892
|
+
page_size: the page size to return (default is 20)
|
|
2893
|
+
page_token: the page token for pagination
|
|
2894
|
+
csv_display: if True, print comma separated values instead of table
|
|
2895
|
+
"""
|
|
2896
|
+
models = self.model_list(sort_by, search, owner, page_size, page_token)
|
|
2897
|
+
fields = ['id', 'ref', 'title', 'subtitle', 'author']
|
|
2898
|
+
if models:
|
|
2899
|
+
if csv_display:
|
|
2900
|
+
self.print_csv(models, fields)
|
|
2901
|
+
else:
|
|
2902
|
+
self.print_table(models, fields)
|
|
2903
|
+
else:
|
|
2904
|
+
print('No models found')
|
|
2905
|
+
|
|
2906
|
+
def model_initialize(self, folder):
|
|
2907
|
+
""" Initialize a folder with a model configuration (metadata) file.
|
|
2908
|
+
Parameters
|
|
2909
|
+
==========
|
|
2910
|
+
folder: the folder to initialize the metadata file in
|
|
2911
|
+
"""
|
|
2912
|
+
if not os.path.isdir(folder):
|
|
2913
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
2914
|
+
|
|
2915
|
+
meta_data = {
|
|
2916
|
+
'ownerSlug':
|
|
2917
|
+
'INSERT_OWNER_SLUG_HERE',
|
|
2918
|
+
'title':
|
|
2919
|
+
'INSERT_TITLE_HERE',
|
|
2920
|
+
'slug':
|
|
2921
|
+
'INSERT_SLUG_HERE',
|
|
2922
|
+
'subtitle':
|
|
2923
|
+
'',
|
|
2924
|
+
'isPrivate':
|
|
2925
|
+
True,
|
|
2926
|
+
'description':
|
|
2927
|
+
'''# Model Summary
|
|
2928
|
+
|
|
2929
|
+
# Model Characteristics
|
|
2930
|
+
|
|
2931
|
+
# Data Overview
|
|
2932
|
+
|
|
2933
|
+
# Evaluation Results
|
|
2934
|
+
''',
|
|
2935
|
+
'publishTime':
|
|
2936
|
+
'',
|
|
2937
|
+
'provenanceSources':
|
|
2938
|
+
''
|
|
2939
|
+
}
|
|
2940
|
+
meta_file = os.path.join(folder, self.MODEL_METADATA_FILE)
|
|
2941
|
+
with open(meta_file, 'w') as f:
|
|
2942
|
+
json.dump(meta_data, f, indent=2)
|
|
2943
|
+
|
|
2944
|
+
print('Model template written to: ' + meta_file)
|
|
2945
|
+
return meta_file
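For illustration only (not part of the package diff): a minimal end-to-end sketch that initializes a folder, edits the generated model-metadata.json, and creates the model; the owner slug, title and slug values are placeholders.
import json
import os
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()

folder = '/tmp/my-model'
os.makedirs(folder, exist_ok=True)
meta_file = api.model_initialize(folder)

with open(meta_file) as f:
    meta = json.load(f)
meta.update({'ownerSlug': 'username', 'title': 'My model', 'slug': 'my-model'})
with open(meta_file, 'w') as f:
    json.dump(meta, f, indent=2)

api.model_create_new_cli(folder)  # prints the new model id and url on success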
|
|
2946
|
+
|
|
2947
|
+
def model_initialize_cli(self, folder=None):
|
|
2948
|
+
folder = folder or os.getcwd()
|
|
2949
|
+
self.model_initialize(folder)
|
|
2950
|
+
|
|
2951
|
+
def model_create_new(self, folder):
|
|
2952
|
+
""" Create a new model.
|
|
2953
|
+
Parameters
|
|
2954
|
+
==========
|
|
2955
|
+
folder: the folder to get the metadata file from
|
|
2956
|
+
"""
|
|
2957
|
+
if not os.path.isdir(folder):
|
|
2958
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
2959
|
+
|
|
2960
|
+
meta_file = self.get_model_metadata_file(folder)
|
|
2961
|
+
|
|
2962
|
+
# read json
|
|
2963
|
+
with open(meta_file) as f:
|
|
2964
|
+
meta_data = json.load(f)
|
|
2965
|
+
owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
|
|
2966
|
+
slug = self.get_or_fail(meta_data, 'slug')
|
|
2967
|
+
title = self.get_or_fail(meta_data, 'title')
|
|
2968
|
+
subtitle = meta_data.get('subtitle')
|
|
2969
|
+
is_private = self.get_or_fail(meta_data, 'isPrivate')
|
|
2970
|
+
description = self.sanitize_markdown(
|
|
2971
|
+
self.get_or_fail(meta_data, 'description'))
|
|
2972
|
+
publish_time = meta_data.get('publishTime')
|
|
2973
|
+
provenance_sources = meta_data.get('provenanceSources')
|
|
2974
|
+
|
|
2975
|
+
# validations
|
|
2976
|
+
if owner_slug == 'INSERT_OWNER_SLUG_HERE':
|
|
2977
|
+
raise ValueError(
|
|
2978
|
+
'Default ownerSlug detected, please change values before uploading')
|
|
2979
|
+
if title == 'INSERT_TITLE_HERE':
|
|
2980
|
+
raise ValueError(
|
|
2981
|
+
'Default title detected, please change values before uploading')
|
|
2982
|
+
if slug == 'INSERT_SLUG_HERE':
|
|
2983
|
+
raise ValueError(
|
|
2984
|
+
'Default slug detected, please change values before uploading')
|
|
2985
|
+
if not isinstance(is_private, bool):
|
|
2986
|
+
raise ValueError('model.isPrivate must be a boolean')
|
|
2987
|
+
if publish_time:
|
|
2988
|
+
self.validate_date(publish_time)
|
|
2989
|
+
else:
|
|
2990
|
+
publish_time = None
|
|
2991
|
+
|
|
2992
|
+
with self.build_kaggle_client() as kaggle:
|
|
2993
|
+
request = ApiCreateModelRequest()
|
|
2994
|
+
request.owner_slug = owner_slug
|
|
2995
|
+
request.slug = slug
|
|
2996
|
+
request.title = title
|
|
2997
|
+
request.subtitle = subtitle
|
|
2998
|
+
request.is_private = is_private
|
|
2999
|
+
request.description = description
|
|
3000
|
+
request.publish_time = publish_time
|
|
3001
|
+
request.provenance_sources = provenance_sources
|
|
3002
|
+
return kaggle.models.model_api_client.create_model(request)
|
|
3003
|
+
|
|
3004
|
+
def model_create_new_cli(self, folder=None):
|
|
3005
|
+
""" Client wrapper for creating a new model.
|
|
3006
|
+
Parameters
|
|
3007
|
+
==========
|
|
3008
|
+
folder: the folder to get the metadata file from
|
|
3009
|
+
"""
|
|
3010
|
+
folder = folder or os.getcwd()
|
|
3011
|
+
result = self.model_create_new(folder)
|
|
3012
|
+
|
|
3013
|
+
if result.hasId:
|
|
3014
|
+
print('Your model was created. Id={}. Url={}'.format(
|
|
3015
|
+
result.id, result.url))
|
|
3016
|
+
else:
|
|
3017
|
+
print('Model creation error: ' + result.error)
|
|
3018
|
+
|
|
3019
|
+
def model_delete(self, model, yes):
|
|
3020
|
+
""" Delete a modeL.
|
|
3021
|
+
Parameters
|
|
3022
|
+
==========
|
|
3023
|
+
model: the string identifier of the model
|
|
3024
|
+
should be in format [owner]/[model-name]
|
|
3025
|
+
yes: automatic confirmation
|
|
3026
|
+
"""
|
|
3027
|
+
owner_slug, model_slug = self.split_model_string(model)
|
|
3028
|
+
|
|
3029
|
+
if not yes:
|
|
3030
|
+
if not self.confirmation():
|
|
3031
|
+
print('Deletion cancelled')
|
|
3032
|
+
exit(0)
|
|
3033
|
+
|
|
3034
|
+
with self.build_kaggle_client() as kaggle:
|
|
3035
|
+
request = ApiDeleteModelRequest()
|
|
3036
|
+
request.owner_slug = owner_slug
|
|
3037
|
+
request.model_slug = model_slug
|
|
3038
|
+
return kaggle.models.model_api_client.delete_model(request)
|
|
3039
|
+
|
|
3040
|
+
def model_delete_cli(self, model, yes):
|
|
3041
|
+
""" Client wrapper for deleting a model.
|
|
3042
|
+
Parameters
|
|
3043
|
+
==========
|
|
3044
|
+
model: the string identifier of the model
|
|
3045
|
+
should be in format [owner]/[model-name]
|
|
3046
|
+
yes: automatic confirmation
|
|
3047
|
+
"""
|
|
3048
|
+
result = self.model_delete(model, yes)
|
|
3049
|
+
|
|
3050
|
+
if result.error:
|
|
3051
|
+
print('Model deletion error: ' + result.error)
|
|
3052
|
+
else:
|
|
3053
|
+
print('The model was deleted.')
|
|
3054
|
+
|
|
3055
|
+
def model_update(self, folder):
|
|
3056
|
+
""" Update a model.
|
|
3057
|
+
Parameters
|
|
3058
|
+
==========
|
|
3059
|
+
folder: the folder to get the metadata file from
|
|
3060
|
+
"""
|
|
3061
|
+
if not os.path.isdir(folder):
|
|
3062
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
3063
|
+
|
|
3064
|
+
meta_file = self.get_model_metadata_file(folder)
|
|
3065
|
+
|
|
3066
|
+
# read json
|
|
3067
|
+
with open(meta_file) as f:
|
|
3068
|
+
meta_data = json.load(f)
|
|
3069
|
+
owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
|
|
3070
|
+
slug = self.get_or_fail(meta_data, 'slug')
|
|
3071
|
+
title = self.get_or_default(meta_data, 'title', None)
|
|
3072
|
+
subtitle = self.get_or_default(meta_data, 'subtitle', None)
|
|
3073
|
+
is_private = self.get_or_default(meta_data, 'isPrivate', None)
|
|
3074
|
+
description = self.get_or_default(meta_data, 'description', None)
|
|
3075
|
+
publish_time = self.get_or_default(meta_data, 'publishTime', None)
|
|
3076
|
+
provenance_sources = self.get_or_default(meta_data, 'provenanceSources',
|
|
3077
|
+
None)
|
|
3078
|
+
|
|
3079
|
+
# validations
|
|
3080
|
+
if owner_slug == 'INSERT_OWNER_SLUG_HERE':
|
|
3081
|
+
raise ValueError(
|
|
3082
|
+
'Default ownerSlug detected, please change values before uploading')
|
|
3083
|
+
if slug == 'INSERT_SLUG_HERE':
|
|
3084
|
+
raise ValueError(
|
|
3085
|
+
'Default slug detected, please change values before uploading')
|
|
3086
|
+
if is_private != None and not isinstance(is_private, bool):
|
|
3087
|
+
raise ValueError('model.isPrivate must be a boolean')
|
|
3088
|
+
if publish_time:
|
|
3089
|
+
self.validate_date(publish_time)
|
|
3090
|
+
|
|
3091
|
+
# mask
|
|
3092
|
+
update_mask = {'paths': []}
|
|
3093
|
+
if title != None:
|
|
3094
|
+
update_mask['paths'].append('title')
|
|
3095
|
+
if subtitle != None:
|
|
3096
|
+
update_mask['paths'].append('subtitle')
|
|
3097
|
+
if is_private != None:
|
|
3098
|
+
update_mask['paths'].append('isPrivate') # is_private
|
|
3099
|
+
else:
|
|
3100
|
+
is_private = True # default value, not updated
|
|
3101
|
+
if description != None:
|
|
3102
|
+
description = self.sanitize_markdown(description)
|
|
3103
|
+
update_mask['paths'].append('description')
|
|
3104
|
+
if publish_time != None and len(publish_time) > 0:
|
|
3105
|
+
update_mask['paths'].append('publish_time')
|
|
3106
|
+
else:
|
|
3107
|
+
publish_time = None
|
|
3108
|
+
if provenance_sources != None and len(provenance_sources) > 0:
|
|
3109
|
+
update_mask['paths'].append('provenance_sources')
|
|
3110
|
+
else:
|
|
3111
|
+
provenance_sources = None
|
|
3112
|
+
|
|
3113
|
+
with self.build_kaggle_client() as kaggle:
|
|
3114
|
+
fm = field_mask_pb2.FieldMask(paths=update_mask['paths'])
|
|
3115
|
+
fm = fm.FromJsonString(json.dumps(update_mask))
|
|
3116
|
+
request = ApiUpdateModelRequest()
|
|
3117
|
+
request.owner_slug = owner_slug
|
|
3118
|
+
request.model_slug = slug
|
|
3119
|
+
request.title = title
|
|
3120
|
+
request.subtitle = subtitle
|
|
3121
|
+
request.is_private = is_private
|
|
3122
|
+
request.description = description
|
|
3123
|
+
request.publish_time = publish_time
|
|
3124
|
+
request.provenance_sources = provenance_sources
|
|
3125
|
+
request.update_mask = fm if len(update_mask['paths']) > 0 else None
|
|
3126
|
+
return kaggle.models.model_api_client.update_model(request)
|
|
3127
|
+
|
|
3128
|
+
def model_update_cli(self, folder=None):
|
|
3129
|
+
""" Client wrapper for updating a model.
|
|
3130
|
+
Parameters
|
|
3131
|
+
==========
|
|
3132
|
+
folder: the folder to get the metadata file from
|
|
3133
|
+
"""
|
|
3134
|
+
folder = folder or os.getcwd()
|
|
3135
|
+
result = self.model_update(folder)
|
|
3136
|
+
|
|
3137
|
+
if result.hasId:
|
|
3138
|
+
print('Your model was updated. Id={}. Url={}'.format(
|
|
3139
|
+
result.id, result.url))
|
|
3140
|
+
else:
|
|
3141
|
+
print('Model update error: ' + result.error)
|
|
3142
|
+
|
|
3143
|
+
def model_instance_get(self, model_instance):
|
|
3144
|
+
""" Get a model instance.
|
|
3145
|
+
Parameters
|
|
3146
|
+
==========
|
|
3147
|
+
model_instance: the string identifier of the model instance
|
|
3148
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3149
|
+
"""
|
|
3150
|
+
if model_instance is None:
|
|
3151
|
+
raise ValueError('A model instance must be specified')
|
|
3152
|
+
owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
|
|
3153
|
+
model_instance)
|
|
3154
|
+
|
|
3155
|
+
with self.build_kaggle_client() as kaggle:
|
|
3156
|
+
request = ApiGetModelInstanceRequest()
|
|
3157
|
+
request.owner_slug = owner_slug
|
|
3158
|
+
request.model_slug = model_slug
|
|
3159
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3160
|
+
request.instance_slug = instance_slug
|
|
3161
|
+
return kaggle.models.model_api_client.get_model_instance(request)
|
|
3162
|
+
|
|
3163
|
+
def model_instance_get_cli(self, model_instance, folder=None):
|
|
3164
|
+
""" Client wrapper for model_instance_get.
|
|
3165
|
+
Parameters
|
|
3166
|
+
==========
|
|
3167
|
+
model_instance: the string identifier of the model instance
|
|
3168
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3169
|
+
folder: the folder to download the model metadata file to
|
|
3170
|
+
"""
|
|
3171
|
+
mi = self.model_instance_get(model_instance)
|
|
3172
|
+
if folder is None:
|
|
3173
|
+
self.print_obj(mi)
|
|
3174
|
+
else:
|
|
3175
|
+
meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE)
|
|
3176
|
+
|
|
3177
|
+
owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
|
|
3178
|
+
model_instance)
|
|
3179
|
+
|
|
3180
|
+
data = {
|
|
3181
|
+
'id': mi.id,
|
|
3182
|
+
'ownerSlug': owner_slug,
|
|
3183
|
+
'modelSlug': model_slug,
|
|
3184
|
+
'instanceSlug': mi.slug,
|
|
3185
|
+
'framework': self.short_enum_name(mi.framework),
|
|
3186
|
+
'overview': mi.overview,
|
|
3187
|
+
'usage': mi.usage,
|
|
3188
|
+
'licenseName': mi.license_name,
|
|
3189
|
+
'fineTunable': mi.fine_tunable,
|
|
3190
|
+
'trainingData': mi.training_data,
|
|
3191
|
+
'versionId': mi.version_id,
|
|
3192
|
+
'versionNumber': mi.version_number,
|
|
3193
|
+
'modelInstanceType': self.short_enum_name(mi.model_instance_type)
|
|
3194
|
+
}
|
|
3195
|
+
if mi.base_model_instance_information is not None:
|
|
3196
|
+
# TODO Test this.
|
|
3197
|
+
data['baseModelInstance'] = '{}/{}/{}/{}'.format(
|
|
3198
|
+
mi.base_model_instance_information['owner']['slug'],
|
|
3199
|
+
mi.base_model_instance_information['modelSlug'],
|
|
3200
|
+
mi.base_model_instance_information['framework'],
|
|
3201
|
+
mi.base_model_instance_information['instanceSlug'])
|
|
3202
|
+
data['externalBaseModelUrl'] = mi.external_base_model_url
|
|
3203
|
+
|
|
3204
|
+
with open(meta_file, 'w') as f:
|
|
3205
|
+
json.dump(data, f, indent=2)
|
|
3206
|
+
print('Metadata file written to {}'.format(meta_file))
|
|
3207
|
+
|
|
3208
|
+
def model_instance_initialize(self, folder):
|
|
3209
|
+
""" Initialize a folder with a model instance configuration (metadata) file.
|
|
3210
|
+
Parameters
|
|
3211
|
+
==========
|
|
3212
|
+
folder: the folder to initialize the metadata file in
|
|
3213
|
+
"""
|
|
3214
|
+
if not os.path.isdir(folder):
|
|
3215
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
3216
|
+
|
|
3217
|
+
meta_data = {
|
|
3218
|
+
'ownerSlug':
|
|
3219
|
+
'INSERT_OWNER_SLUG_HERE',
|
|
3220
|
+
'modelSlug':
|
|
3221
|
+
'INSERT_EXISTING_MODEL_SLUG_HERE',
|
|
3222
|
+
'instanceSlug':
|
|
3223
|
+
'INSERT_INSTANCE_SLUG_HERE',
|
|
3224
|
+
'framework':
|
|
3225
|
+
'INSERT_FRAMEWORK_HERE',
|
|
3226
|
+
'overview':
|
|
3227
|
+
'',
|
|
3228
|
+
'usage':
|
|
3229
|
+
'''# Model Format
|
|
3230
|
+
|
|
3231
|
+
# Training Data
|
|
3232
|
+
|
|
3233
|
+
# Model Inputs
|
|
3234
|
+
|
|
3235
|
+
# Model Outputs
|
|
3236
|
+
|
|
3237
|
+
# Model Usage
|
|
3238
|
+
|
|
3239
|
+
# Fine-tuning
|
|
3240
|
+
|
|
3241
|
+
# Changelog
|
|
3242
|
+
''',
|
|
3243
|
+
'licenseName':
|
|
3244
|
+
'Apache 2.0',
|
|
3245
|
+
'fineTunable':
|
|
3246
|
+
False,
|
|
3247
|
+
'trainingData': [],
|
|
3248
|
+
'modelInstanceType':
|
|
3249
|
+
'Unspecified',
|
|
3250
|
+
'baseModelInstanceId':
|
|
3251
|
+
0,
|
|
3252
|
+
'externalBaseModelUrl':
|
|
3253
|
+
''
|
|
3254
|
+
}
|
|
3255
|
+
meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE)
|
|
3256
|
+
with open(meta_file, 'w') as f:
|
|
3257
|
+
json.dump(meta_data, f, indent=2)
|
|
3258
|
+
|
|
3259
|
+
print('Model Instance template written to: ' + meta_file)
|
|
3260
|
+
return meta_file
|
|
3261
|
+
|
|
3262
|
+
def model_instance_initialize_cli(self, folder):
|
|
3263
|
+
folder = folder or os.getcwd()
|
|
3264
|
+
self.model_instance_initialize(folder)
|
|
3265
|
+
|
|
3266
|
+
def model_instance_create(self, folder, quiet=False, dir_mode='skip'):
|
|
3267
|
+
""" Create a new model instance.
|
|
3268
|
+
Parameters
|
|
3269
|
+
==========
|
|
3270
|
+
folder: the folder to get the metadata file from
|
|
3271
|
+
quiet: suppress verbose output (default is False)
|
|
3272
|
+
dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
3273
|
+
"""
|
|
3274
|
+
if not os.path.isdir(folder):
|
|
3275
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
3276
|
+
|
|
3277
|
+
meta_file = self.get_model_instance_metadata_file(folder)
|
|
3278
|
+
|
|
3279
|
+
# read json
|
|
3280
|
+
with open(meta_file) as f:
|
|
3281
|
+
meta_data = json.load(f)
|
|
3282
|
+
owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
|
|
3283
|
+
model_slug = self.get_or_fail(meta_data, 'modelSlug')
|
|
3284
|
+
instance_slug = self.get_or_fail(meta_data, 'instanceSlug')
|
|
3285
|
+
framework = self.get_or_fail(meta_data, 'framework')
|
|
3286
|
+
overview = self.sanitize_markdown(
|
|
3287
|
+
self.get_or_default(meta_data, 'overview', ''))
|
|
3288
|
+
usage = self.sanitize_markdown(self.get_or_default(meta_data, 'usage', ''))
|
|
3289
|
+
license_name = self.get_or_fail(meta_data, 'licenseName')
|
|
3290
|
+
fine_tunable = self.get_or_default(meta_data, 'fineTunable', False)
|
|
3291
|
+
training_data = self.get_or_default(meta_data, 'trainingData', [])
|
|
3292
|
+
model_instance_type = self.get_or_default(meta_data, 'modelInstanceType',
|
|
3293
|
+
'Unspecified')
|
|
3294
|
+
base_model_instance = self.get_or_default(meta_data, 'baseModelInstance',
|
|
3295
|
+
'')
|
|
3296
|
+
external_base_model_url = self.get_or_default(meta_data,
|
|
3297
|
+
'externalBaseModelUrl', '')
|
|
3298
|
+
|
|
3299
|
+
# validations
|
|
3300
|
+
if owner_slug == 'INSERT_OWNER_SLUG_HERE':
|
|
3301
|
+
raise ValueError(
|
|
3302
|
+
'Default ownerSlug detected, please change values before uploading')
|
|
3303
|
+
if model_slug == 'INSERT_EXISTING_MODEL_SLUG_HERE':
|
|
3304
|
+
raise ValueError(
|
|
3305
|
+
'Default modelSlug detected, please change values before uploading')
|
|
3306
|
+
if instance_slug == 'INSERT_INSTANCE_SLUG_HERE':
|
|
3307
|
+
raise ValueError(
|
|
3308
|
+
'Default instanceSlug detected, please change values before uploading'
|
|
3309
|
+
)
|
|
3310
|
+
if framework == 'INSERT_FRAMEWORK_HERE':
|
|
3311
|
+
raise ValueError(
|
|
3312
|
+
'Default framework detected, please change values before uploading')
|
|
3313
|
+
if license_name == '':
|
|
3314
|
+
raise ValueError('Please specify a license')
|
|
3315
|
+
if not isinstance(fine_tunable, bool):
|
|
3316
|
+
raise ValueError('modelInstance.fineTunable must be a boolean')
|
|
3317
|
+
if not isinstance(training_data, list):
|
|
3318
|
+
raise ValueError('modelInstance.trainingData must be a list')
|
|
3319
|
+
|
|
3320
|
+
body = ApiCreateModelInstanceRequestBody()
|
|
3321
|
+
body.framework = self.lookup_enum(ModelFramework, framework)
|
|
3322
|
+
body.instance_slug = instance_slug
|
|
3323
|
+
body.overview = overview
|
|
3324
|
+
body.usage = usage
|
|
3325
|
+
body.license_name = license_name
|
|
3326
|
+
body.fine_tunable = fine_tunable
|
|
3327
|
+
body.training_data = training_data
|
|
3328
|
+
body.model_instance_type = self.lookup_enum(ModelInstanceType,
|
|
3329
|
+
model_instance_type)
|
|
3330
|
+
body.base_model_instance = base_model_instance
|
|
3331
|
+
body.external_base_model_url = external_base_model_url
|
|
3332
|
+
body.files = []
|
|
3333
|
+
|
|
3334
|
+
with self.build_kaggle_client() as kaggle:
|
|
3335
|
+
request = ApiCreateModelInstanceRequest()
|
|
3336
|
+
request.owner_slug = owner_slug
|
|
3337
|
+
request.model_slug = model_slug
|
|
3338
|
+
request.body = body
|
|
3339
|
+
message = kaggle.models.model_api_client.create_model_instance
|
|
3340
|
+
with ResumableUploadContext() as upload_context:
|
|
3341
|
+
self.upload_files(body, None, folder, ApiBlobType.MODEL, upload_context,
|
|
3342
|
+
quiet, dir_mode)
|
|
3343
|
+
request.body.files = [
|
|
3344
|
+
self._api_dataset_new_file(file) for file in request.body.files
|
|
3345
|
+
]
|
|
3346
|
+
response = self.with_retry(message)(request)
|
|
3347
|
+
return response
|
|
3348
|
+
|
|
3349
|
+
def model_instance_create_cli(self, folder, quiet=False, dir_mode='skip'):
|
|
3350
|
+
""" Client wrapper for creating a new model instance.
|
|
3351
|
+
Parameters
|
|
3352
|
+
==========
|
|
3353
|
+
folder: the folder to get the metadata file from
|
|
3354
|
+
quiet: suppress verbose output (default is False)
|
|
3355
|
+
dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
3356
|
+
"""
|
|
3357
|
+
folder = folder or os.getcwd()
|
|
3358
|
+
result = self.model_instance_create(folder, quiet, dir_mode)
|
|
3359
|
+
|
|
3360
|
+
if result.hasId:
|
|
3361
|
+
print('Your model instance was created. Id={}. Url={}'.format(
|
|
3362
|
+
result.id, result.url))
|
|
3363
|
+
else:
|
|
3364
|
+
print('Model instance creation error: ' + result.error)
|
|
3365
|
+
|
|
3366
|
+
def model_instance_delete(self, model_instance, yes):
|
|
3367
|
+
""" Delete a model instance.
|
|
3368
|
+
Parameters
|
|
3369
|
+
==========
|
|
3370
|
+
model_instance: the string identifier of the model instance
|
|
3371
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3372
|
+
yes: automatic confirmation
|
|
3373
|
+
"""
|
|
3374
|
+
if model_instance is None:
|
|
3375
|
+
raise ValueError('A model instance must be specified')
|
|
3376
|
+
owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
|
|
3377
|
+
model_instance)
|
|
3378
|
+
|
|
3379
|
+
if not yes:
|
|
3380
|
+
if not self.confirmation():
|
|
3381
|
+
print('Deletion cancelled')
|
|
3382
|
+
exit(0)
|
|
3383
|
+
|
|
3384
|
+
with self.build_kaggle_client() as kaggle:
|
|
3385
|
+
request = ApiDeleteModelInstanceRequest()
|
|
3386
|
+
request.owner_slug = owner_slug
|
|
3387
|
+
request.model_slug = model_slug
|
|
3388
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3389
|
+
request.instance_slug = instance_slug
|
|
3390
|
+
return kaggle.models.model_api_client.delete_model_instance(request)
|
|
3392
|
+
|
|
3393
|
+
def model_instance_delete_cli(self, model_instance, yes):
|
|
3394
|
+
""" Client wrapper for model_instance_delete.
|
|
3395
|
+
Parameters
|
|
3396
|
+
==========
|
|
3397
|
+
model_instance: the string identifier of the model instance
|
|
3398
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3399
|
+
yes: automatic confirmation
|
|
3400
|
+
"""
|
|
3401
|
+
result = self.model_instance_delete(model_instance, yes)
|
|
3402
|
+
|
|
3403
|
+
if len(result.error) > 0:
|
|
3404
|
+
print('Model instance deletion error: ' + result.error)
|
|
3405
|
+
else:
|
|
3406
|
+
print('The model instance was deleted.')
|
|
3407
|
+
|
|
3408
|
+
def model_instance_files(self,
|
|
3409
|
+
model_instance,
|
|
3410
|
+
page_token=None,
|
|
3411
|
+
page_size=20,
|
|
3412
|
+
csv_display=False):
|
|
3413
|
+
""" List files for the current version of a model instance.
|
|
3414
|
+
|
|
3415
|
+
Parameters
|
|
3416
|
+
==========
|
|
3417
|
+
model_instance: the string identifier of the model instance
|
|
3418
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3419
|
+
page_token: token for pagination
|
|
3420
|
+
page_size: the number of items per page
|
|
3421
|
+
csv_display: if True, print comma separated values instead of table
|
|
3422
|
+
"""
|
|
3423
|
+
if model_instance is None:
|
|
3424
|
+
raise ValueError('A model_instance must be specified')
|
|
3425
|
+
|
|
3426
|
+
self.validate_model_instance_string(model_instance)
|
|
3427
|
+
urls = model_instance.split('/')
|
|
3428
|
+
[owner_slug, model_slug, framework, instance_slug] = urls
|
|
3429
|
+
|
|
3430
|
+
with self.build_kaggle_client() as kaggle:
|
|
3431
|
+
request = ApiListModelInstanceVersionFilesRequest()
|
|
3432
|
+
request.owner_slug = owner_slug
|
|
3433
|
+
request.model_slug = model_slug
|
|
3434
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3435
|
+
request.instance_slug = instance_slug
|
|
3436
|
+
request.page_size = page_size
|
|
3437
|
+
request.page_token = page_token
|
|
3438
|
+
response = kaggle.models.model_api_client.list_model_instance_version_files(
|
|
3439
|
+
request)
|
|
3440
|
+
|
|
3441
|
+
if response:
|
|
3442
|
+
next_page_token = response.next_page_token
|
|
3443
|
+
if next_page_token:
|
|
3444
|
+
print('Next Page Token = {}'.format(next_page_token))
|
|
3445
|
+
return response
|
|
3446
|
+
else:
|
|
3447
|
+
print('No files found')
|
|
3448
|
+
return FileList({})
|
|
3449
|
+
|
|
3450
|
+
def model_instance_files_cli(self,
|
|
3451
|
+
model_instance,
|
|
3452
|
+
page_token=None,
|
|
3453
|
+
page_size=20,
|
|
3454
|
+
csv_display=False):
|
|
3455
|
+
""" Client wrapper for model_instance_files.
|
|
3456
|
+
|
|
3457
|
+
Parameters
|
|
3458
|
+
==========
|
|
3459
|
+
model_instance: the string identifier of the model instance
|
|
3460
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3461
|
+
page_token: token for pagination
|
|
3462
|
+
page_size: the number of items per page
|
|
3463
|
+
csv_display: if True, print comma separated values instead of table
|
|
3464
|
+
"""
|
|
3465
|
+
result = self.model_instance_files(
|
|
3466
|
+
model_instance,
|
|
3467
|
+
page_token=page_token,
|
|
3468
|
+
page_size=page_size,
|
|
3469
|
+
csv_display=csv_display)
|
|
3470
|
+
if result and result.files is not None:
|
|
3471
|
+
fields = self.dataset_file_fields
|
|
3472
|
+
if csv_display:
|
|
3473
|
+
self.print_csv(result.files, fields)
|
|
3474
|
+
else:
|
|
3475
|
+
self.print_table(result.files, fields)
|
|
3476
|
+
|
|
3477
|
+
def model_instance_update(self, folder):
|
|
3478
|
+
""" Update a model instance.
|
|
3479
|
+
Parameters
|
|
3480
|
+
==========
|
|
3481
|
+
folder: the folder to get the metadata file from
|
|
3482
|
+
"""
|
|
3483
|
+
if not os.path.isdir(folder):
|
|
3484
|
+
raise ValueError('Invalid folder: ' + folder)
|
|
3485
|
+
|
|
3486
|
+
meta_file = self.get_model_instance_metadata_file(folder)
|
|
3487
|
+
|
|
3488
|
+
# read json
|
|
3489
|
+
with open(meta_file) as f:
|
|
3490
|
+
meta_data = json.load(f)
|
|
3491
|
+
owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
|
|
3492
|
+
model_slug = self.get_or_fail(meta_data, 'modelSlug')
|
|
3493
|
+
framework = self.get_or_fail(meta_data, 'framework')
|
|
3494
|
+
instance_slug = self.get_or_fail(meta_data, 'instanceSlug')
|
|
3495
|
+
overview = self.get_or_default(meta_data, 'overview', '')
|
|
3496
|
+
usage = self.get_or_default(meta_data, 'usage', '')
|
|
3497
|
+
license_name = self.get_or_default(meta_data, 'licenseName', None)
|
|
3498
|
+
fine_tunable = self.get_or_default(meta_data, 'fineTunable', None)
|
|
3499
|
+
training_data = self.get_or_default(meta_data, 'trainingData', None)
|
|
3500
|
+
model_instance_type = self.get_or_default(meta_data, 'modelInstanceType',
|
|
3501
|
+
None)
|
|
3502
|
+
base_model_instance = self.get_or_default(meta_data, 'baseModelInstance',
|
|
3503
|
+
None)
|
|
3504
|
+
external_base_model_url = self.get_or_default(meta_data,
|
|
3505
|
+
'externalBaseModelUrl', None)
|
|
3506
|
+
|
|
3507
|
+
# validations
|
|
3508
|
+
if owner_slug == 'INSERT_OWNER_SLUG_HERE':
|
|
3509
|
+
raise ValueError(
|
|
3510
|
+
'Default ownerSlug detected, please change values before uploading')
|
|
3511
|
+
if model_slug == 'INSERT_SLUG_HERE':
|
|
3512
|
+
raise ValueError(
|
|
3513
|
+
'Default model slug detected, please change values before uploading')
|
|
3514
|
+
if instance_slug == 'INSERT_INSTANCE_SLUG_HERE':
|
|
3515
|
+
raise ValueError(
|
|
3516
|
+
'Default instance slug detected, please change values before uploading'
|
|
3517
|
+
)
|
|
3518
|
+
if framework == 'INSERT_FRAMEWORK_HERE':
|
|
3519
|
+
raise ValueError(
|
|
3520
|
+
'Default framework detected, please change values before uploading')
|
|
3521
|
+
if fine_tunable != None and not isinstance(fine_tunable, bool):
|
|
3522
|
+
raise ValueError('modelInstance.fineTunable must be a boolean')
|
|
3523
|
+
if training_data != None and not isinstance(training_data, list):
|
|
3524
|
+
raise ValueError('modelInstance.trainingData must be a list')
|
|
3525
|
+
if model_instance_type:
|
|
3526
|
+
model_instance_type = self.lookup_enum(ModelInstanceType,
|
|
3527
|
+
model_instance_type)
|
|
3528
|
+
|
|
3529
|
+
# mask
|
|
3530
|
+
update_mask = {'paths': []}
|
|
3531
|
+
if overview != None:
|
|
3532
|
+
overview = self.sanitize_markdown(overview)
|
|
3533
|
+
update_mask['paths'].append('overview')
|
|
3534
|
+
if usage != None:
|
|
3535
|
+
usage = self.sanitize_markdown(usage)
|
|
3536
|
+
update_mask['paths'].append('usage')
|
|
3537
|
+
if license_name != None:
|
|
3538
|
+
update_mask['paths'].append('licenseName')
|
|
3539
|
+
else:
|
|
3540
|
+
license_name = "Apache 2.0" # default value even if not updated
|
|
3541
|
+
if fine_tunable != None:
|
|
3542
|
+
update_mask['paths'].append('fineTunable')
|
|
3543
|
+
if training_data != None:
|
|
3544
|
+
update_mask['paths'].append('trainingData')
|
|
3545
|
+
if model_instance_type != None:
|
|
3546
|
+
update_mask['paths'].append('modelInstanceType')
|
|
3547
|
+
if base_model_instance != None:
|
|
3548
|
+
update_mask['paths'].append('baseModelInstance')
|
|
3549
|
+
if external_base_model_url != None:
|
|
3550
|
+
update_mask['paths'].append('externalBaseModelUrl')
|
|
3551
|
+
|
|
3552
|
+
with self.build_kaggle_client() as kaggle:
|
|
3553
|
+
fm = field_mask_pb2.FieldMask(paths=update_mask['paths'])
|
|
3554
|
+
fm = fm.FromJsonString(json.dumps(update_mask))
|
|
3555
|
+
request = ApiUpdateModelInstanceRequest()
|
|
3556
|
+
request.owner_slug = owner_slug
|
|
3557
|
+
request.model_slug = model_slug
|
|
3558
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3559
|
+
request.instance_slug = instance_slug
|
|
3560
|
+
request.overview = overview
|
|
3561
|
+
request.usage = usage
|
|
3562
|
+
request.license_name = license_name
|
|
3563
|
+
request.fine_tunable = fine_tunable
|
|
3564
|
+
request.training_data = training_data
|
|
3565
|
+
request.model_instance_type = model_instance_type
|
|
3566
|
+
request.base_model_instance = base_model_instance
|
|
3567
|
+
request.external_base_model_url = external_base_model_url
|
|
3569
|
+
request.update_mask = fm if len(update_mask['paths']) > 0 else None
|
|
3570
|
+
return kaggle.models.model_api_client.update_model_instance(request)
|
|
3571
|
+
|
|
3572
|
+
def model_instance_update_cli(self, folder=None):
|
|
3573
|
+
""" Client wrapper for updating a model instance.
|
|
3574
|
+
Parameters
|
|
3575
|
+
==========
|
|
3576
|
+
folder: the folder to get the metadata file from
|
|
3577
|
+
"""
|
|
3578
|
+
folder = folder or os.getcwd()
|
|
3579
|
+
result = self.model_instance_update(folder)
|
|
3580
|
+
|
|
3581
|
+
if len(result.error) == 0:
|
|
3582
|
+
print('Your model instance was updated. Id={}. Url={}'.format(
|
|
3583
|
+
result.id, result.url))
|
|
3584
|
+
else:
|
|
3585
|
+
print('Model update error: ' + result.error)
|
|
3586
|
+
|
|
3587
|
+
def model_instance_version_create(self,
|
|
3588
|
+
model_instance,
|
|
3589
|
+
folder,
|
|
3590
|
+
version_notes='',
|
|
3591
|
+
quiet=False,
|
|
3592
|
+
dir_mode='skip'):
|
|
3593
|
+
""" Create a new model instance version.
|
|
3594
|
+
Parameters
|
|
3595
|
+
==========
|
|
3596
|
+
model_instance: the string identifier of the model instance
|
|
3597
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3598
|
+
folder: the folder to get the metadata file from
|
|
3599
|
+
version_notes: the version notes to record for this new version
|
|
3600
|
+
quiet: suppress verbose output (default is False)
|
|
3601
|
+
dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
3602
|
+
"""
|
|
3603
|
+
owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
|
|
3604
|
+
model_instance)
|
|
3605
|
+
|
|
3606
|
+
request = ApiCreateModelInstanceVersionRequest()
|
|
3607
|
+
request.owner_slug = owner_slug
|
|
3608
|
+
request.model_slug = model_slug
|
|
3609
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3610
|
+
request.instance_slug = instance_slug
|
|
3611
|
+
body = ApiCreateModelInstanceVersionRequestBody()
|
|
3612
|
+
body.version_notes = version_notes
|
|
3613
|
+
request.body = body
|
|
3614
|
+
with self.build_kaggle_client() as kaggle:
|
|
3615
|
+
message = kaggle.models.model_api_client.create_model_instance_version
|
|
3616
|
+
with ResumableUploadContext() as upload_context:
|
|
3617
|
+
self.upload_files(body, None, folder, ApiBlobType.MODEL, upload_context,
|
|
3618
|
+
quiet, dir_mode)
|
|
3619
|
+
request.body.files = [
|
|
3620
|
+
self._api_dataset_new_file(file) for file in request.body.files
|
|
3621
|
+
]
|
|
3622
|
+
response = self.with_retry(message)(request)
|
|
3623
|
+
return response
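For illustration only (not part of the package diff): a minimal usage sketch of model_instance_version_create, assuming the usual KaggleApi entry point and that /tmp/weights holds the files to upload; the owner/model/framework/instance path is a placeholder.
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()
result = api.model_instance_version_create(
    'username/my-model/pyTorch/default',  # owner/model-name/framework/instance-slug
    '/tmp/weights',
    version_notes='first upload',
    dir_mode='zip')
# Mirrors the check used by the CLI wrapper below.
print(result.url if result.id != 0 else result.error)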
|
|
3624
|
+
|
|
3625
|
+
def model_instance_version_create_cli(self,
|
|
3626
|
+
model_instance,
|
|
3627
|
+
folder,
|
|
3628
|
+
version_notes='',
|
|
3629
|
+
quiet=False,
|
|
3630
|
+
dir_mode='skip'):
|
|
3631
|
+
""" Client wrapper for creating a new version of a model instance.
|
|
3632
|
+
Parameters
|
|
3633
|
+
==========
|
|
3634
|
+
model_instance: the string identifier of the model instance
|
|
3635
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]
|
|
3636
|
+
folder: the folder to get the metadata file from
|
|
3637
|
+
version_notes: the version notes to record for this new version
|
|
3638
|
+
quiet: suppress verbose output (default is False)
|
|
3639
|
+
dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
|
|
3640
|
+
"""
|
|
3641
|
+
result = self.model_instance_version_create(model_instance, folder,
|
|
3642
|
+
version_notes, quiet, dir_mode)
|
|
3643
|
+
|
|
3644
|
+
if result.id != 0:
|
|
3645
|
+
print('Your model instance version was created. Url={}'.format(
|
|
3646
|
+
result.url))
|
|
3647
|
+
else:
|
|
3648
|
+
print('Model instance version creation error: ' + result.error)
|
|
3649
|
+
|
|
3650
|
+
def model_instance_version_download(self,
|
|
3651
|
+
model_instance_version,
|
|
3652
|
+
path=None,
|
|
3653
|
+
force=False,
|
|
3654
|
+
quiet=True,
|
|
3655
|
+
untar=False):
|
|
3656
|
+
""" Download all files for a model instance version.
|
|
3657
|
+
|
|
3658
|
+
Parameters
|
|
3659
|
+
==========
|
|
3660
|
+
model_instance_version: the string identifier of the model instance version
|
|
3661
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
|
|
3662
|
+
path: the path to download the model instance version to
|
|
3663
|
+
force: force the download if the file already exists (default False)
|
|
3664
|
+
quiet: suppress verbose output (default is True)
|
|
3665
|
+
untar: if True, untar files upon download (default is False)
|
|
3666
|
+
"""
|
|
3667
|
+
if model_instance_version is None:
|
|
3668
|
+
raise ValueError('A model_instance_version must be specified')
|
|
3669
|
+
|
|
3670
|
+
self.validate_model_instance_version_string(model_instance_version)
|
|
3671
|
+
urls = model_instance_version.split('/')
|
|
3672
|
+
owner_slug = urls[0]
|
|
3673
|
+
model_slug = urls[1]
|
|
3674
|
+
framework = urls[2]
|
|
3675
|
+
instance_slug = urls[3]
|
|
3676
|
+
version_number = urls[4]
|
|
3677
|
+
|
|
3678
|
+
if path is None:
|
|
3679
|
+
effective_path = self.get_default_download_dir('models', owner_slug,
|
|
3680
|
+
model_slug, framework,
|
|
3681
|
+
instance_slug,
|
|
3682
|
+
version_number)
|
|
3683
|
+
else:
|
|
3684
|
+
effective_path = path
|
|
3685
|
+
|
|
3686
|
+
request = ApiDownloadModelInstanceVersionRequest()
|
|
3687
|
+
request.owner_slug = owner_slug
|
|
3688
|
+
request.model_slug = model_slug
|
|
3689
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3690
|
+
request.instance_slug = instance_slug
|
|
3691
|
+
request.version_number = int(version_number)
|
|
3692
|
+
with self.build_kaggle_client() as kaggle:
|
|
3693
|
+
response = kaggle.models.model_api_client.download_model_instance_version(
|
|
3694
|
+
request)
|
|
3695
|
+
|
|
3696
|
+
outfile = os.path.join(effective_path, model_slug + '.tar.gz')
|
|
3697
|
+
if force or self.download_needed(response, outfile, quiet):
|
|
3698
|
+
self.download_file(response, outfile, quiet, not force)
|
|
3699
|
+
downloaded = True
|
|
3700
|
+
else:
|
|
3701
|
+
downloaded = False
|
|
3702
|
+
|
|
3703
|
+
if downloaded:
|
|
3704
|
+
if untar:
|
|
3705
|
+
try:
|
|
3706
|
+
with tarfile.open(outfile, mode='r:gz') as t:
|
|
3707
|
+
t.extractall(effective_path)
|
|
3708
|
+
except Exception as e:
|
|
3709
|
+
raise ValueError(
|
|
3710
|
+
'Error extracting the tar.gz file, please report on '
|
|
3711
|
+
'www.github.com/kaggle/kaggle-api', e)
|
|
3712
|
+
|
|
3713
|
+
try:
|
|
3714
|
+
os.remove(outfile)
|
|
3715
|
+
except OSError as e:
|
|
3716
|
+
print('Could not delete tar file, got %s' % e)
|
|
3717
|
+
return outfile
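For illustration only (not part of the package diff): a minimal usage sketch of model_instance_version_download, assuming the usual KaggleApi entry point; the version path is a placeholder. With untar=True the downloaded .tar.gz is extracted into the target directory and then removed.
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()
archive = api.model_instance_version_download(
    'username/my-model/pyTorch/default/1',
    path='/tmp/model-download',
    force=True,
    untar=True)
print('downloaded to', archive)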
|
|
3718
|
+
|
|
3719
|
+
def model_instance_version_download_cli(self,
|
|
3720
|
+
model_instance_version,
|
|
3721
|
+
path=None,
|
|
3722
|
+
untar=False,
|
|
3723
|
+
force=False,
|
|
3724
|
+
quiet=False):
|
|
3725
|
+
""" Client wrapper for model_instance_version_download.
|
|
3726
|
+
|
|
3727
|
+
Parameters
|
|
3728
|
+
==========
|
|
3729
|
+
model_instance_version: the string identifier of the model instance version
|
|
3730
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
|
|
3731
|
+
path: the path to download the model instance version to
|
|
3732
|
+
force: force the download if the file already exists (default False)
|
|
3733
|
+
quiet: suppress verbose output (default is False)
|
|
3734
|
+
untar: if True, untar files upon download (default is False)
|
|
3735
|
+
"""
|
|
3736
|
+
return self.model_instance_version_download(
|
|
3737
|
+
model_instance_version,
|
|
3738
|
+
path=path,
|
|
3739
|
+
untar=untar,
|
|
3740
|
+
force=force,
|
|
3741
|
+
quiet=quiet)
|
|
3742
|
+
|
|
3743
|
+
def model_instance_version_files(self,
|
|
3744
|
+
model_instance_version,
|
|
3745
|
+
page_token=None,
|
|
3746
|
+
page_size=20,
|
|
3747
|
+
csv_display=False):
|
|
3748
|
+
""" List all files for a model instance version.
|
|
3749
|
+
|
|
3750
|
+
Parameters
|
|
3751
|
+
==========
|
|
3752
|
+
model_instance_version: the string identifier of the model instance version
|
|
3753
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
|
|
3754
|
+
page_token: token for pagination
|
|
3755
|
+
page_size: the number of items per page
|
|
3756
|
+
csv_display: if True, print comma separated values instead of table
|
|
3757
|
+
"""
|
|
3758
|
+
if model_instance_version is None:
|
|
3759
|
+
raise ValueError('A model_instance_version must be specified')
|
|
3760
|
+
|
|
3761
|
+
self.validate_model_instance_version_string(model_instance_version)
|
|
3762
|
+
urls = model_instance_version.split('/')
|
|
3763
|
+
[owner_slug, model_slug, framework, instance_slug, version_number] = urls
|
|
3764
|
+
|
|
3765
|
+
request = ApiListModelInstanceVersionFilesRequest()
|
|
3766
|
+
request.owner_slug = owner_slug
|
|
3767
|
+
request.model_slug = model_slug
|
|
3768
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3769
|
+
request.instance_slug = instance_slug
|
|
3770
|
+
request.version_number = int(version_number)
|
|
3771
|
+
request.page_size = page_size
|
|
3772
|
+
request.page_token = page_token
|
|
3773
|
+
with self.build_kaggle_client() as kaggle:
|
|
3774
|
+
response = kaggle.models.model_api_client.list_model_instance_version_files(
|
|
3775
|
+
request)
|
|
3776
|
+
|
|
3777
|
+
if response:
|
|
3778
|
+
next_page_token = response.next_page_token
|
|
3779
|
+
if next_page_token:
|
|
3780
|
+
print('Next Page Token = {}'.format(next_page_token))
|
|
3781
|
+
return response
|
|
3782
|
+
else:
|
|
3783
|
+
print('No files found')
|
|
3784
|
+
|
|
3785
|
+
def model_instance_version_files_cli(self,
|
|
3786
|
+
model_instance_version,
|
|
3787
|
+
page_token=None,
|
|
3788
|
+
page_size=20,
|
|
3789
|
+
csv_display=False):
|
|
3790
|
+
""" Client wrapper for model_instance_version_files.
|
|
3791
|
+
|
|
3792
|
+
Parameters
|
|
3793
|
+
==========
|
|
3794
|
+
model_instance_version: the string identifier of the model instance version
|
|
3795
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
|
|
3796
|
+
page_token: token for pagination
|
|
3797
|
+
page_size: the number of items per page
|
|
3798
|
+
csv_display: if True, print comma separated values instead of table
|
|
3799
|
+
"""
|
|
3800
|
+
result = self.model_instance_version_files(
|
|
3801
|
+
model_instance_version,
|
|
3802
|
+
page_token=page_token,
|
|
3803
|
+
page_size=page_size,
|
|
3804
|
+
csv_display=csv_display)
|
|
3805
|
+
if result and result.files is not None:
|
|
3806
|
+
fields = ['name', 'size', 'creation_date']
|
|
3807
|
+
labels = ['name', 'size', 'creationDate']
|
|
3808
|
+
if csv_display:
|
|
3809
|
+
self.print_csv(result.files, fields, labels)
|
|
3810
|
+
else:
|
|
3811
|
+
self.print_table(result.files, fields, labels)
|
|
3812
|
+
|
|
3813
|
+
def model_instance_version_delete(self, model_instance_version, yes):
|
|
3814
|
+
""" Delete a model instance version.
|
|
3815
|
+
Parameters
|
|
3816
|
+
==========
|
|
3817
|
+
model_instance_version: the string identifier of the model instance version
|
|
3818
|
+
should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
|
|
3819
|
+
yes: automatic confirmation
|
|
3820
|
+
"""
|
|
3821
|
+
if model_instance_version is None:
|
|
3822
|
+
raise ValueError('A model instance version must be specified')
|
|
3823
|
+
|
|
3824
|
+
self.validate_model_instance_version_string(model_instance_version)
|
|
3825
|
+
urls = model_instance_version.split('/')
|
|
3826
|
+
owner_slug = urls[0]
|
|
3827
|
+
model_slug = urls[1]
|
|
3828
|
+
framework = urls[2]
|
|
3829
|
+
instance_slug = urls[3]
|
|
3830
|
+
version_number = urls[4]
|
|
3831
|
+
|
|
3832
|
+
if not yes:
|
|
3833
|
+
if not self.confirmation():
|
|
3834
|
+
print('Deletion cancelled')
|
|
3835
|
+
exit(0)
|
|
3836
|
+
|
|
3837
|
+
request = ApiDeleteModelInstanceVersionRequest()
|
|
3838
|
+
request.owner_slug = owner_slug
|
|
3839
|
+
request.model_slug = model_slug
|
|
3840
|
+
request.framework = self.lookup_enum(ModelFramework, framework)
|
|
3841
|
+
request.instance_slug = instance_slug
|
|
3842
|
+
request.version_number = int(version_number)
|
|
3843
|
+
with self.build_kaggle_client() as kaggle:
|
|
3844
|
+
response = kaggle.models.model_api_client.delete_model_instance_version(
|
|
3845
|
+
request)
|
|
3846
|
+
return response
|
|
3847
|
+
|
  def model_instance_version_delete_cli(self, model_instance_version, yes):
    """ Client wrapper for model_instance_version_delete
    Parameters
    ==========
    model_instance_version: the string identifier of the model instance version
      should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
    yes: automatic confirmation
    """
    result = self.model_instance_version_delete(model_instance_version, yes)

    if len(result.error) > 0:
      print('Model instance version deletion error: ' + result.error)
    else:
      print('The model instance version was deleted.')

  def files_upload_cli(self, local_paths, inbox_path, no_resume, no_compress):
    if len(local_paths) > self.MAX_NUM_INBOX_FILES_TO_UPLOAD:
      print('Cannot upload more than %d files!' %
            self.MAX_NUM_INBOX_FILES_TO_UPLOAD)
      return

    files_to_create = []
    with ResumableUploadContext(no_resume) as upload_context:
      for local_path in local_paths:
        (upload_file, file_name) = self.file_upload_cli(local_path, inbox_path,
                                                        no_compress,
                                                        upload_context)
        if upload_file is None:
          continue

        create_inbox_file_request = CreateInboxFileRequest()
        create_inbox_file_request.virtual_directory = inbox_path
        create_inbox_file_request.blob_file_token = upload_file.token
        files_to_create.append((create_inbox_file_request, file_name))

      with self.build_kaggle_client() as kaggle:
        create_inbox_file = kaggle.admin.inbox_file_client.create_inbox_file
        for (create_inbox_file_request, file_name) in files_to_create:
          self.with_retry(create_inbox_file)(create_inbox_file_request)
          print('Inbox file created:', file_name)

  def file_upload_cli(self, local_path, inbox_path, no_compress,
                      upload_context):
    full_path = os.path.abspath(local_path)
    parent_path = os.path.dirname(full_path)
    file_or_folder_name = os.path.basename(full_path)
    dir_mode = 'tar' if no_compress else 'zip'

    upload_file = self._upload_file_or_folder(parent_path, file_or_folder_name,
                                              ApiBlobType.INBOX, upload_context,
                                              dir_mode)
    return (upload_file, file_or_folder_name)

  def print_obj(self, obj, indent=2):
    pretty = json.dumps(obj, indent=indent)
    print(pretty)

  def download_needed(self, response, outfile, quiet=True):
    """ determine if a download is needed based on timestamp. Return True
    if needed (remote is newer) or False if local is newest.
    Parameters
    ==========
    response: the response from the API
    outfile: the output file to write to
    quiet: suppress verbose output (default is True)
    """
    try:
      last_modified = response.headers.get('Last-Modified')
      if last_modified is None:
        remote_date = datetime.now()
      else:
        remote_date = datetime.strptime(response.headers['Last-Modified'],
                                        '%a, %d %b %Y %H:%M:%S %Z')
      file_exists = os.path.isfile(outfile)
      if file_exists:
        local_date = datetime.fromtimestamp(os.path.getmtime(outfile))
        remote_size = int(response.headers['Content-Length'])
        local_size = os.path.getsize(outfile)
        if local_size < remote_size:
          return True
        if remote_date <= local_date:
          if not quiet:
            print(
                os.path.basename(outfile) +
                ': Skipping, found more recently modified local '
                'copy (use --force to force download)')
          return False
    except:
      pass
    return True

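Illustrative note, not part of the package source: download_needed compares the server's Last-Modified header against the local file's mtime and size. A standalone sketch of that freshness check, assuming an RFC 1123 date string and a placeholder output file name:

import os
from datetime import datetime

def is_remote_newer(last_modified_header, outfile):
  remote_date = datetime.strptime(last_modified_header,
                                  '%a, %d %b %Y %H:%M:%S %Z')
  if not os.path.isfile(outfile):
    return True  # nothing on disk yet, so a download is needed
  local_date = datetime.fromtimestamp(os.path.getmtime(outfile))
  return remote_date > local_date

print(is_remote_newer('Tue, 01 Oct 2024 12:00:00 GMT', 'train.csv'))
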
  def print_table(self, items, fields, labels=None):
    """ print a table of items, for a set of fields defined

    Parameters
    ==========
    items: a list of items to print
    fields: a list of fields to select from items
    labels: labels for the fields, defaults to fields
    """
    if labels is None:
      labels = fields
    formats = []
    borders = []
    if len(items) == 0:
      return
    for f in fields:
      length = max(
          len(f),
          max([
              len(self.string(getattr(i, self.camel_to_snake(f))))
              for i in items
          ]))
      justify = '>' if isinstance(
          getattr(items[0], self.camel_to_snake(f)),
          int) or f == 'size' or f == 'reward' else '<'
      formats.append('{:' + justify + self.string(length + 2) + '}')
      borders.append('-' * length + ' ')
    row_format = u''.join(formats)
    headers = [f + ' ' for f in labels]
    print(row_format.format(*headers))
    print(row_format.format(*borders))
    for i in items:
      i_fields = [
          self.string(getattr(i, self.camel_to_snake(f))) + ' ' for f in fields
      ]
      try:
        print(row_format.format(*i_fields))
      except UnicodeEncodeError:
        print(row_format.format(*i_fields).encode('utf-8'))

  def print_csv(self, items, fields, labels=None):
    """ print a set of fields in a set of items using a csv.writer

    Parameters
    ==========
    items: a list of items to print
    fields: a list of fields to select from items
    labels: labels for the fields, defaults to fields
    """
    if labels is None:
      labels = fields
    writer = csv.writer(sys.stdout)
    writer.writerow(labels)
    for i in items:
      i_fields = [
          self.string(getattr(i, self.camel_to_snake(f))) for f in fields
      ]
      writer.writerow(i_fields)

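Illustrative note, not part of the package source: print_table pads every column to its widest value and right-justifies numeric-looking columns such as 'size'. A small sketch of that format-string construction, using plain dicts instead of API objects:

items = [{'name': 'train.csv', 'size': 125000}, {'name': 'test.csv', 'size': 4096}]
fields = ['name', 'size']

formats = []
for f in fields:
  length = max(len(f), max(len(str(i[f])) for i in items))
  justify = '>' if f == 'size' else '<'  # numeric columns right-justified
  formats.append('{:' + justify + str(length + 2) + '}')
row_format = ''.join(formats)

print(row_format.format(*[f + ' ' for f in fields]))
for i in items:
  print(row_format.format(*[str(i[f]) + ' ' for f in fields]))
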
  def string(self, item):
    return item if isinstance(item, str) else str(item)

  def get_or_fail(self, data, key):
    if key in data:
      return data[key]
    raise ValueError('Key ' + key + ' not found in data')

  def get_or_default(self, data, key, default):
    if key in data:
      return data[key]
    return default

  def get_bool(self, data, key, default):
    if key in data:
      val = data[key]
      if isinstance(val, str):
        val = val.lower()
        if val == 'true':
          return True
        elif val == 'false':
          return False
        else:
          raise ValueError('Invalid boolean value: ' + val)
      if isinstance(val, bool):
        return val
      raise ValueError('Invalid boolean value: ' + val)
    return default

  def set_if_present(self, data, key, output, output_key):
    if key in data:
      output[output_key] = data[key]

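Illustrative note, not part of the package source: get_bool accepts the strings 'true'/'false' in any case and real booleans, raises on anything else, and falls back to the default for a missing key. A standalone restatement with sample calls:

def get_bool(data, key, default):
  if key not in data:
    return default
  val = data[key]
  if isinstance(val, str):
    if val.lower() == 'true':
      return True
    if val.lower() == 'false':
      return False
    raise ValueError('Invalid boolean value: ' + val)
  if isinstance(val, bool):
    return val
  raise ValueError('Invalid boolean value: %r' % val)

print(get_bool({'isPrivate': 'False'}, 'isPrivate', True))  # False
print(get_bool({}, 'isPrivate', True))                      # True (default)
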
  def get_dataset_metadata_file(self, folder):
    meta_file = os.path.join(folder, self.DATASET_METADATA_FILE)
    if not os.path.isfile(meta_file):
      meta_file = os.path.join(folder, self.OLD_DATASET_METADATA_FILE)
      if not os.path.isfile(meta_file):
        raise ValueError('Metadata file not found: ' +
                         self.DATASET_METADATA_FILE)
    return meta_file

  def get_model_metadata_file(self, folder):
    meta_file = os.path.join(folder, self.MODEL_METADATA_FILE)
    if not os.path.isfile(meta_file):
      raise ValueError('Metadata file not found: ' + self.MODEL_METADATA_FILE)
    return meta_file

  def get_model_instance_metadata_file(self, folder):
    meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE)
    if not os.path.isfile(meta_file):
      raise ValueError('Metadata file not found: ' +
                       self.MODEL_INSTANCE_METADATA_FILE)
    return meta_file

  def process_response(self, result):
    """ process a response from the API. We check the API version against
    the client's to see if it's old, and give them a warning (once)

    Parameters
    ==========
    result: the result from the API
    """
    if len(result) == 3:
      data = result[0]
      headers = result[2]
      if self.HEADER_API_VERSION in headers:
        api_version = headers[self.HEADER_API_VERSION]
        if (not self.already_printed_version_warning and
            not self.is_up_to_date(api_version)):
          print(f'Warning: Looks like you\'re using an outdated `kaggle` '
                f'version (installed: {self.__version__}), please consider '
                f'upgrading to the latest version ({api_version})')
          self.already_printed_version_warning = True
      if isinstance(data, dict) and 'code' in data and data['code'] != 200:
        raise Exception(data['message'])
      return data
    return result

  def is_up_to_date(self, server_version):
    """ determine if a client (on the local user's machine) is up to date
    with the version provided on the server. Return a boolean with True
    or False
    Parameters
    ==========
    server_version: the server version string to compare to the host
    """
    client_split = self.__version__.split('.')
    client_len = len(client_split)
    server_split = server_version.split('.')
    server_len = len(server_split)

    # Make both lists the same length
    for i in range(client_len, server_len):
      client_split.append('0')
    for i in range(server_len, client_len):
      server_split.append('0')

    for i in range(0, client_len):
      if 'a' in client_split[i] or 'b' in client_split[i]:
        # Using an alpha/beta version, don't check
        return True
      client = int(client_split[i])
      server = int(server_split[i])
      if client < server:
        return False
      elif server < client:
        return True

    return True

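Illustrative note, not part of the package source: the comparison walks the dotted components left to right and stops at the first difference; alpha/beta components short-circuit to "up to date". A worked example with made-up version strings:

client = '1.6.17'.split('.')
server = '1.7.3'.split('.')
for c, s in zip(client, server):
  if int(c) < int(s):
    print('outdated')    # 6 < 7, so this client would be warned
    break
  if int(s) < int(c):
    print('up to date')
    break
else:
  print('up to date')    # all compared components equal
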
  def upload_files(self,
                   request,
                   resources,
                   folder,
                   blob_type,
                   upload_context,
                   quiet=False,
                   dir_mode='skip'):
    """ upload files in a folder
    Parameters
    ==========
    request: the prepared request
    resources: the files to upload
    folder: the folder to upload from
    blob_type (ApiBlobType): To which entity the file/blob refers
    upload_context (ResumableUploadContext): Context for resumable uploads
    quiet: suppress verbose output (default is False)
    """
    for file_name in os.listdir(folder):
      if (file_name in [
          self.DATASET_METADATA_FILE, self.OLD_DATASET_METADATA_FILE,
          self.KERNEL_METADATA_FILE, self.MODEL_METADATA_FILE,
          self.MODEL_INSTANCE_METADATA_FILE
      ]):
        continue
      upload_file = self._upload_file_or_folder(folder, file_name, blob_type,
                                                upload_context, dir_mode, quiet,
                                                resources)
      if upload_file is not None:
        request.files.append(upload_file)

  def _upload_file_or_folder(self,
                             parent_path,
                             file_or_folder_name,
                             blob_type,
                             upload_context,
                             dir_mode,
                             quiet=False,
                             resources=None):
    full_path = os.path.join(parent_path, file_or_folder_name)
    if os.path.isfile(full_path):
      return self._upload_file(file_or_folder_name, full_path, blob_type,
                               upload_context, quiet, resources)

    elif os.path.isdir(full_path):
      if dir_mode in ['zip', 'tar']:
        with DirectoryArchive(full_path, dir_mode) as archive:
          return self._upload_file(archive.name, archive.path, blob_type,
                                   upload_context, quiet, resources)
      elif not quiet:
        print("Skipping folder: " + file_or_folder_name +
              "; use '--dir-mode' to upload folders")
    else:
      if not quiet:
        print('Skipping: ' + file_or_folder_name)
    return None

  def _upload_file(self, file_name, full_path, blob_type, upload_context, quiet,
                   resources):
    """ Helper function to upload a single file
    Parameters
    ==========
    file_name: name of the file to upload
    full_path: path to the file to upload
    blob_type (ApiBlobType): To which entity the file/blob refers
    upload_context (ResumableUploadContext): Context for resumable uploads
    quiet: suppress verbose output
    resources: optional file metadata
    :return: None - upload unsuccessful; instance of UploadFile - upload successful
    """

    if not quiet:
      print('Starting upload for file ' + file_name)

    content_length = os.path.getsize(full_path)
    token = self._upload_blob(full_path, quiet, blob_type, upload_context)
    if token is None:
      if not quiet:
        print('Upload unsuccessful: ' + file_name)
      return None
    if not quiet:
      print('Upload successful: ' + file_name + ' (' +
            File.get_size(content_length) + ')')
    upload_file = UploadFile()
    upload_file.token = token
    if resources:
      for item in resources:
        if file_name == item.get('path'):
          upload_file.description = item.get('description')
          if 'schema' in item:
            fields = self.get_or_default(item['schema'], 'fields', [])
            processed = []
            count = 0
            for field in fields:
              processed.append(self.process_column(field))
              processed[count].order = count
              count += 1
            upload_file.columns = processed
    return upload_file

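Illustrative note, not part of the package source: the resources list matched by path above comes from the dataset metadata file. A hypothetical entry showing the fields _upload_file reads (path, description, and an optional schema whose fields become columns):

resources = [{
    'path': 'train.csv',
    'description': 'Training split',
    'schema': {
        'fields': [
            {'name': 'id', 'description': 'row id', 'type': 'integer'},
            {'name': 'price', 'description': 'sale price', 'type': 'number'},
        ]
    }
}]
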
  def process_column(self, column):
    """ process a column, check for the type, and return the processed
    column
    Parameters
    ==========
    column: a dictionary describing the column (name, description, type)
    """
    processed_column = DatasetColumn(
        name=self.get_or_fail(column, 'name'),
        description=self.get_or_default(column, 'description', ''))
    if 'type' in column:
      original_type = column['type'].lower()
      processed_column.original_type = original_type
      if (original_type == 'string' or original_type == 'date' or
          original_type == 'time' or original_type == 'yearmonth' or
          original_type == 'duration' or original_type == 'geopoint' or
          original_type == 'geojson'):
        processed_column.type = 'string'
      elif (original_type == 'numeric' or original_type == 'number' or
            original_type == 'year'):
        processed_column.type = 'numeric'
      elif original_type == 'boolean':
        processed_column.type = 'boolean'
      elif original_type == 'datetime':
        processed_column.type = 'datetime'
      else:
        # Possibly extended data type - not going to try to track those
        # here. Will set the type and let the server handle it.
        processed_column.type = original_type
    return processed_column

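Illustrative note, not part of the package source: process_column folds many source types into a few canonical ones and forwards anything unrecognized to the server unchanged. A quick view of that mapping as plain data:

folded_types = {
    'date': 'string', 'geopoint': 'string', 'time': 'string',
    'number': 'numeric', 'year': 'numeric',
    'boolean': 'boolean', 'datetime': 'datetime',
    'integer': 'integer',  # not special-cased above, passed through as-is
}
for original, folded in folded_types.items():
  print(original, '->', folded)
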
  def upload_complete(self, path, url, quiet, resume=False):
    """ function to complete an upload to retrieve a path from a url
    Parameters
    ==========
    path: the path for the upload that is read in
    url: the url to send the PUT to
    quiet: suppress verbose output (default is False)
    """
    file_size = os.path.getsize(path)
    resumable_upload_result = ResumableUploadResult.Incomplete()

    try:
      if resume:
        resumable_upload_result = self._resume_upload(path, url, file_size,
                                                      quiet)
        if resumable_upload_result.result != ResumableUploadResult.INCOMPLETE:
          return resumable_upload_result.result

      start_at = resumable_upload_result.start_at
      upload_size = file_size - start_at

      with tqdm(
          total=upload_size,
          unit='B',
          unit_scale=True,
          unit_divisor=1024,
          disable=quiet) as progress_bar:
        with io.open(path, 'rb', buffering=0) as fp:
          session = requests.Session()
          if start_at > 0:
            fp.seek(start_at)
            session.headers.update({
                'Content-Length':
                    '%d' % upload_size,
                'Content-Range':
                    'bytes %d-%d/%d' % (start_at, file_size - 1, file_size)
            })
          reader = TqdmBufferedReader(fp, progress_bar)
          retries = Retry(total=10, backoff_factor=0.5)
          adapter = HTTPAdapter(max_retries=retries)
          session.mount('http://', adapter)
          session.mount('https://', adapter)
          response = session.put(url, data=reader)
          if self._is_upload_successful(response):
            return ResumableUploadResult.COMPLETE
          if response.status_code == 503:
            return ResumableUploadResult.INCOMPLETE
          # Server returned a non-resumable error so give up.
          return ResumableUploadResult.FAILED
    except Exception as error:
      print(error)
      # There is probably some weird bug in our code so try to resume the upload
      # in case it works on the next try.
      return ResumableUploadResult.INCOMPLETE

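Illustrative note, not part of the package source: when a resumable upload restarts partway through, the request body carries only the remaining bytes and the Content-Range header uses an inclusive end offset. The arithmetic for a 5,000-byte file resumed after 1,000 bytes:

file_size = 5000
start_at = 1000
upload_size = file_size - start_at
headers = {
    'Content-Length': '%d' % upload_size,  # '4000'
    'Content-Range': 'bytes %d-%d/%d' % (start_at, file_size - 1, file_size),
}
print(headers['Content-Range'])  # bytes 1000-4999/5000
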
  def _resume_upload(self, path, url, content_length, quiet):
    # Documentation: https://developers.google.com/drive/api/guides/manage-uploads#resume-upload
    session = requests.Session()
    session.headers.update({
        'Content-Length': '0',
        'Content-Range': 'bytes */%d' % content_length,
    })

    response = session.put(url)

    if self._is_upload_successful(response):
      return ResumableUploadResult.Complete()
    if response.status_code == 404:
      # Upload expired so need to start from scratch.
      if not quiet:
        print('Upload of %s expired. Please try again.' % path)
      return ResumableUploadResult.Failed()
    if response.status_code == 308:  # Resume Incomplete
      bytes_uploaded = self._get_bytes_already_uploaded(response, quiet)
      if bytes_uploaded is None:
        # There is an error with the Range header so need to start from scratch.
        return ResumableUploadResult.Failed()
      result = ResumableUploadResult.Incomplete(bytes_uploaded)
      if not quiet:
        print('Already uploaded %d bytes. Will resume upload at %d.' %
              (result.bytes_uploaded, result.start_at))
      return result
    else:
      if not quiet:
        print('Server returned %d. Please try again.' % response.status_code)
      return ResumableUploadResult.Failed()

  def _is_upload_successful(self, response):
    return response.status_code == 200 or response.status_code == 201

  def _get_bytes_already_uploaded(self, response, quiet):
    range_val = response.headers.get('Range')
    if range_val is None:
      return 0  # This means the server hasn't received anything before.
    items = range_val.split('-')  # Example: bytes=0-1000 => ['bytes=0', '1000']
    if len(items) != 2:
      if not quiet:
        print('Invalid Range header format: %s. Will try again.' % range_val)
      return None  # Shouldn't happen, something's wrong with Range header format.
    bytes_uploaded_str = items[-1]  # Example: ['bytes=0', '1000'] => '1000'
    try:
      return int(bytes_uploaded_str)  # Example: '1000' => 1000
    except ValueError:
      if not quiet:
        print('Invalid Range header format: %s. Will try again.' % range_val)
      return None  # Shouldn't happen, something's wrong with Range header format.

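Illustrative note, not part of the package source: on a 308 response the server's Range header reports how much it already holds, and the code above resumes from the number after the dash. A one-line sketch with a made-up header value:

range_header = 'bytes=0-999'              # hypothetical server reply
print(int(range_header.split('-')[-1]))   # 999, used as the resume offset
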
  def validate_dataset_string(self, dataset):
    """ determine if a dataset string is valid, meaning it is in the format
    of {username}/{dataset-slug} or {username}/{dataset-slug}/{version-number}.
    Parameters
    ==========
    dataset: the dataset name to validate
    """
    if dataset:
      if '/' not in dataset:
        raise ValueError('Dataset must be specified in the form of '
                         '\'{username}/{dataset-slug}\'')

      split = dataset.split('/')
      if not split[0] or not split[1] or len(split) > 3:
        raise ValueError('Invalid dataset specification ' + dataset)

  def split_dataset_string(self, dataset):
    """ split a dataset string into owner_slug, dataset_slug,
    and optional version_number
    Parameters
    ==========
    dataset: the dataset name to split
    """
    if '/' in dataset:
      self.validate_dataset_string(dataset)
      urls = dataset.split('/')
      if len(urls) == 3:
        return urls[0], urls[1], urls[2]
      else:
        return urls[0], urls[1], None
    else:
      return self.get_config_value(self.CONFIG_NAME_USER), dataset, None

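Illustrative note, not part of the package source: split_dataset_string accepts three shapes. With a made-up owner 'alice' configured as the username:

print('alice/titanic'.split('/'))    # ['alice', 'titanic']       -> version None
print('alice/titanic/3'.split('/'))  # ['alice', 'titanic', '3']  -> version '3'
# a bare 'titanic' falls back to the configured user: ('alice', 'titanic', None)
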
  def validate_model_string(self, model):
    """ determine if a model string is valid, meaning it is in the format
    of {owner}/{model-slug}.
    Parameters
    ==========
    model: the model name to validate
    """
    if model:
      if model.count('/') != 1:
        raise ValueError('Model must be specified in the form of '
                         '\'{owner}/{model-slug}\'')

      split = model.split('/')
      if not split[0] or not split[1]:
        raise ValueError('Invalid model specification ' + model)

  def split_model_string(self, model):
    """ split a model string into owner_slug, model_slug
    Parameters
    ==========
    model: the model name to split
    """
    if '/' in model:
      self.validate_model_string(model)
      model_urls = model.split('/')
      return model_urls[0], model_urls[1]
    else:
      return self.get_config_value(self.CONFIG_NAME_USER), model

  def validate_model_instance_string(self, model_instance):
    """ determine if a model instance string is valid, meaning it is in the format
    of {owner}/{model-slug}/{framework}/{instance-slug}.
    Parameters
    ==========
    model_instance: the model instance name to validate
    """
    if model_instance:
      if model_instance.count('/') != 3:
        raise ValueError('Model instance must be specified in the form of '
                         '\'{owner}/{model-slug}/{framework}/{instance-slug}\'')

      split = model_instance.split('/')
      if not split[0] or not split[1] or not split[2] or not split[3]:
        raise ValueError('Invalid model instance specification ' +
                         model_instance)

  def split_model_instance_string(self, model_instance):
    """ split a model instance string into owner_slug, model_slug,
    framework, instance_slug
    Parameters
    ==========
    model_instance: the model instance name to validate
    """
    self.validate_model_instance_string(model_instance)
    urls = model_instance.split('/')
    return urls[0], urls[1], urls[2], urls[3]

  def validate_model_instance_version_string(self, model_instance_version):
    """ determine if a model instance version string is valid, meaning it is in the format
    of {owner}/{model-slug}/{framework}/{instance-slug}/{version-number}.
    Parameters
    ==========
    model_instance_version: the model instance version name to validate
    """
    if model_instance_version:
      if model_instance_version.count('/') != 4:
        raise ValueError(
            'Model instance version must be specified in the form of '
            '\'{owner}/{model-slug}/{framework}/{instance-slug}/{version-number}\''
        )

      split = model_instance_version.split('/')
      if not split[0] or not split[1] or not split[2] or not split[
          3] or not split[4]:
        raise ValueError('Invalid model instance version specification ' +
                         model_instance_version)

      try:
        version_number = int(split[4])
      except:
        raise ValueError(
            'Model instance version\'s version-number must be an integer')

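Illustrative note, not part of the package source: a well-formed model instance version handle has exactly four slashes and an integer final component. A made-up example:

handle = 'alice/my-model/jax/default/3'
parts = handle.split('/')
assert len(parts) == 5 and all(parts) and parts[4].isdigit()
owner, model_slug, framework, instance_slug, version = parts
print(owner, model_slug, framework, instance_slug, int(version))
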
  def validate_kernel_string(self, kernel):
    """ determine if a kernel string is valid, meaning it is in the format
    of {username}/{kernel-slug}.
    Parameters
    ==========
    kernel: the kernel name to validate
    """
    if kernel:
      if '/' not in kernel:
        raise ValueError('Kernel must be specified in the form of '
                         '\'{username}/{kernel-slug}\'')

      split = kernel.split('/')
      if not split[0] or not split[1]:
        raise ValueError('Kernel must be specified in the form of '
                         '\'{username}/{kernel-slug}\'')

      if len(split[1]) < 5:
        raise ValueError('Kernel slug must be at least five characters')

  def validate_model_string(self, model):
    """ determine if a model string is valid, meaning it is in the format
    of {username}/{model-slug}/{framework}/{variation-slug}/{version-number}.
    Parameters
    ==========
    model: the model name to validate
    """
    if model:
      if '/' not in model:
        raise ValueError(
            'Model must be specified in the form of '
            '\'{username}/{model-slug}/{framework}/{variation-slug}/{version-number}\''
        )

      split = model.split('/')
      if not split[0] or not split[1]:
        raise ValueError('Invalid model specification ' + model)

  def validate_resources(self, folder, resources):
    """ validate resources is a wrapper to validate the existence of files
    and that there are no duplicates for a folder and set of resources.

    Parameters
    ==========
    folder: the folder to validate
    resources: one or more resources to validate within the folder
    """
    self.validate_files_exist(folder, resources)
    self.validate_no_duplicate_paths(resources)

  def validate_files_exist(self, folder, resources):
    """ ensure that one or more resource files exist in a folder

    Parameters
    ==========
    folder: the folder to validate
    resources: one or more resources to validate within the folder
    """
    for item in resources:
      file_name = item.get('path')
      full_path = os.path.join(folder, file_name)
      if not os.path.isfile(full_path):
        raise ValueError('%s does not exist' % full_path)

  def validate_no_duplicate_paths(self, resources):
    """ ensure that the user has not provided duplicate paths in
    a list of resources.

    Parameters
    ==========
    resources: one or more resources to validate not duplicated
    """
    paths = set()
    for item in resources:
      file_name = item.get('path')
      if file_name in paths:
        raise ValueError(
            '%s path was specified more than once in the metadata' % file_name)
      paths.add(file_name)

  def convert_to_dataset_file_metadata(self, file_data, path):
    """ convert a set of file_data to a metadata file at path

    Parameters
    ==========
    file_data: a dictionary of file data to write to file
    path: the path to write the metadata to
    """
    as_metadata = {
        'path': os.path.join(path, file_data['name']),
        'description': file_data['description']
    }

    schema = {}
    fields = []
    for column in file_data['columns']:
      field = {
          'name': column['name'],
          'title': column['description'],
          'type': column['type']
      }
      fields.append(field)
    schema['fields'] = fields
    as_metadata['schema'] = schema

    return as_metadata

  def validate_date(self, date):
    datetime.strptime(date, "%Y-%m-%d")

  def sanitize_markdown(self, markdown):
    return bleach.clean(markdown)

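Illustrative note, not part of the package source: given made-up file_data from a dataset file listing, convert_to_dataset_file_metadata produces roughly the structure shown in the trailing comment:

file_data = {
    'name': 'train.csv',
    'description': 'Training split',
    'columns': [{'name': 'id', 'description': 'row id', 'type': 'integer'}],
}
# convert_to_dataset_file_metadata(file_data, 'data') ->
# {'path': 'data/train.csv',
#  'description': 'Training split',
#  'schema': {'fields': [{'name': 'id', 'title': 'row id', 'type': 'integer'}]}}
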
  def confirmation(self):
    question = "Are you sure?"
    prompt = "[yes/no]"
    options = {"yes": True, "no": False}
    while True:
      sys.stdout.write('{} {} '.format(question, prompt))
      choice = input().lower()
      if choice in options:
        return options[choice]
      else:
        sys.stdout.write("Please respond with 'yes' or 'no'.\n")
        return False


class TqdmBufferedReader(io.BufferedReader):

  def __init__(self, raw, progress_bar):
    """ helper class to implement an io.BufferedReader
    Parameters
    ==========
    raw: bytes data to pass to the buffered reader
    progress_bar: a progress bar to initialize the reader
    """
    io.BufferedReader.__init__(self, raw)
    self.progress_bar = progress_bar

  def read(self, *args, **kwargs):
    """ read the buffer, passing named and non named arguments to the
    io.BufferedReader function.
    """
    buf = io.BufferedReader.read(self, *args, **kwargs)
    self.increment(len(buf))
    return buf

  def increment(self, length):
    """ increment the reader by some length

    Parameters
    ==========
    length: bytes to increment the reader by
    """
    self.progress_bar.update(length)

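Illustrative note, not part of the package source: upload_complete wires this reader together with a tqdm bar so every read() advances the bar while the HTTP client streams the body. A minimal sketch using the class above and a placeholder file name:

import io
import os
from tqdm import tqdm

path = 'model.bin'  # hypothetical local file
with tqdm(total=os.path.getsize(path), unit='B', unit_scale=True) as bar:
  with io.open(path, 'rb', buffering=0) as fp:
    reader = TqdmBufferedReader(fp, bar)
    while reader.read(1024 * 1024):
      pass  # the real upload hands `reader` to requests as the PUT body
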

class FileList(object):

  def __init__(self, init_dict):
    self.error_message = ''
    files = init_dict['files']
    if files:
      for f in files:
        if 'size' in f:
          f['totalBytes'] = f['size']
      self.files = [File(f) for f in files]
    else:
      self.files = []
    token = init_dict['nextPageToken']
    if token:
      self.nextPageToken = token
    else:
      self.nextPageToken = ""

  def __repr__(self):
    return ''

# This defines print_attributes(), which is very handy for inspecting
# objects returned by the Kaggle API.

from pprint import pprint
from inspect import getmembers
from types import FunctionType


def attributes(obj):
  disallowed_names = {
      name for name, value in getmembers(type(obj))
      if isinstance(value, FunctionType)}
  return {
      name: getattr(obj, name) for name in dir(obj)
      if name[0] != '_' and name not in disallowed_names and hasattr(obj, name)}


def print_attributes(obj):
  pprint(attributes(obj))