codeshift-0.4.0-py3-none-any.whl → codeshift-0.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/__init__.py +1 -1
- codeshift/cli/commands/auth.py +41 -25
- codeshift/cli/commands/health.py +244 -0
- codeshift/cli/commands/upgrade.py +68 -55
- codeshift/cli/main.py +2 -0
- codeshift/health/__init__.py +50 -0
- codeshift/health/calculator.py +217 -0
- codeshift/health/metrics/__init__.py +63 -0
- codeshift/health/metrics/documentation.py +209 -0
- codeshift/health/metrics/freshness.py +180 -0
- codeshift/health/metrics/migration_readiness.py +142 -0
- codeshift/health/metrics/security.py +225 -0
- codeshift/health/metrics/test_coverage.py +191 -0
- codeshift/health/models.py +284 -0
- codeshift/health/report.py +310 -0
- codeshift/knowledge/generator.py +6 -0
- codeshift/knowledge_base/libraries/aiohttp.yaml +3 -3
- codeshift/knowledge_base/libraries/httpx.yaml +4 -4
- codeshift/knowledge_base/libraries/pytest.yaml +1 -1
- codeshift/knowledge_base/models.py +1 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +50 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +191 -22
- codeshift/scanner/code_scanner.py +22 -2
- codeshift/utils/api_client.py +144 -4
- codeshift/utils/credential_store.py +393 -0
- codeshift/utils/llm_client.py +111 -9
- {codeshift-0.4.0.dist-info → codeshift-0.7.0.dist-info}/METADATA +4 -1
- {codeshift-0.4.0.dist-info → codeshift-0.7.0.dist-info}/RECORD +32 -20
- {codeshift-0.4.0.dist-info → codeshift-0.7.0.dist-info}/WHEEL +0 -0
- {codeshift-0.4.0.dist-info → codeshift-0.7.0.dist-info}/entry_points.txt +0 -0
- {codeshift-0.4.0.dist-info → codeshift-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.4.0.dist-info → codeshift-0.7.0.dist-info}/top_level.txt +0 -0
codeshift/migrator/transforms/marshmallow_transformer.py
CHANGED

@@ -191,6 +191,10 @@ class MarshmallowTransformer(BaseTransformer):
         - default -> dump_default
         - load_from -> data_key
         - dump_to -> data_key
+
+        Special handling: When both load_from and dump_to are present, only one data_key
+        is kept (preferring load_from) and a warning comment is added about the removed
+        dump_to value.
         """
         # Check if this is a fields.* call or a Field-like call
         func_name = self._get_call_func_name(node.func)

@@ -232,6 +236,27 @@ class MarshmallowTransformer(BaseTransformer):
         if func_name not in field_types:
             return node

+        # First pass: detect if both load_from and dump_to are present
+        load_from_arg = None
+        dump_to_arg = None
+        load_from_value = None
+        dump_to_value = None
+
+        for arg in node.args:
+            if isinstance(arg.keyword, cst.Name):
+                if arg.keyword.value == "load_from":
+                    load_from_arg = arg
+                    # Extract the value for comparison/warning
+                    if isinstance(arg.value, cst.SimpleString):
+                        load_from_value = arg.value.value
+                elif arg.keyword.value == "dump_to":
+                    dump_to_arg = arg
+                    # Extract the value for comparison/warning
+                    if isinstance(arg.value, cst.SimpleString):
+                        dump_to_value = arg.value.value
+
+        has_both_load_from_and_dump_to = load_from_arg is not None and dump_to_arg is not None
+
         new_args = []
         changed = False
         param_mappings = {

@@ -245,6 +270,31 @@ class MarshmallowTransformer(BaseTransformer):
             if isinstance(arg.keyword, cst.Name) and arg.keyword.value in param_mappings:
                 old_name = arg.keyword.value
                 new_name = param_mappings[old_name]
+
+                # Special case: skip dump_to when both load_from and dump_to exist
+                if old_name == "dump_to" and has_both_load_from_and_dump_to:
+                    changed = True
+                    # Record that dump_to was removed due to conflict
+                    self.record_change(
+                        description=(
+                            f"Remove '{old_name}' parameter - Marshmallow 3.x uses single "
+                            f"data_key for both load/dump. load_from value kept, dump_to="
+                            f"{dump_to_value} removed. Manual review may be needed if "
+                            f"load_from ({load_from_value}) != dump_to ({dump_to_value})."
+                        ),
+                        line_number=1,
+                        original=f"{func_name}(load_from=..., dump_to=...)",
+                        replacement=f"{func_name}(data_key=...)",
+                        transform_name="remove_dump_to_conflict",
+                        notes=(
+                            f"dump_to={dump_to_value} was removed because load_from="
+                            f"{load_from_value} was also present. In Marshmallow 3.x, "
+                            "data_key serves both purposes."
+                        ),
+                    )
+                    # Skip adding this arg
+                    continue
+
                 new_arg = arg.with_changes(keyword=cst.Name(new_name))
                 new_args.append(new_arg)
                 changed = True
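To illustrate the effect of these transforms, here is a hypothetical Marshmallow 2.x field declaration and the 3.x form the transformer aims to produce (the field and key names are invented for the example):

    # Before (Marshmallow 2.x keywords)
    name = fields.String(load_from="userName", dump_to="user_name", default="anon")

    # After the transform (Marshmallow 3.x keywords)
    name = fields.String(data_key="userName", dump_default="anon")

Because data_key now covers both serialization directions, the dump_to="user_name" value is dropped and a change record with a warning is emitted so the conflict can be reviewed manually.
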
codeshift/migrator/transforms/pydantic_v1_to_v2.py
CHANGED

@@ -24,12 +24,151 @@ class PydanticV1ToV2Transformer(BaseTransformer):
         self._current_class: str | None = None
         # Track position info
         self._line_offset = 0
+        # Track Pydantic model classes defined in this file
+        self._pydantic_model_classes: set[str] = set()
+        # Track variables known to be Pydantic model instances
+        self._pydantic_instance_vars: set[str] = set()
+        # Track function parameters with Pydantic model type hints
+        self._pydantic_param_vars: set[str] = set()
+        # Track if BaseModel is imported from pydantic
+        self._has_basemodel_import = False
+
+    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
+        """Track Pydantic imports to identify model base classes."""
+        if node.module is None:
+            return True
+
+        module_name = self._get_module_name(node.module)
+        if module_name == "pydantic" or module_name.startswith("pydantic."):
+            if isinstance(node.names, cst.ImportStar):
+                # With star import, assume BaseModel is available
+                self._has_basemodel_import = True
+            elif isinstance(node.names, tuple):
+                for name in node.names:
+                    if isinstance(name, cst.ImportAlias):
+                        imported_name = self._get_name_value(name.name)
+                        if imported_name == "BaseModel":
+                            self._has_basemodel_import = True
+        return True

     def visit_ClassDef(self, node: cst.ClassDef) -> bool:
-        """Track the current class being visited."""
+        """Track the current class being visited and detect Pydantic models."""
         self._current_class = node.name.value
+
+        # Check if this class inherits from BaseModel or another known Pydantic model
+        for base in node.bases:
+            base_name = self._get_base_class_name(base.value)
+            if (
+                base_name in ("BaseModel", "pydantic.BaseModel")
+                or base_name in self._pydantic_model_classes
+            ):
+                self._pydantic_model_classes.add(node.name.value)
+                break
+        return True
+
+    def _get_base_class_name(self, node: cst.BaseExpression) -> str:
+        """Get the name of a base class from its AST node."""
+        if isinstance(node, cst.Name):
+            return node.value
+        if isinstance(node, cst.Attribute):
+            return f"{self._get_base_class_name(node.value)}.{node.attr.value}"
+        if isinstance(node, cst.Subscript):
+            # Handle Generic[T] style - get the base
+            return self._get_base_class_name(node.value)
+        return ""
+
+    def visit_Assign(self, node: cst.Assign) -> bool:
+        """Track assignments of Pydantic model instances to variables."""
+        # Check if the value is a call to a Pydantic model class
+        if isinstance(node.value, cst.Call):
+            class_name = self._get_call_func_name(node.value.func)
+            if class_name in self._pydantic_model_classes:
+                # Track all assigned variable names
+                for target in node.targets:
+                    if isinstance(target.target, cst.Name):
+                        self._pydantic_instance_vars.add(target.target.value)
+        return True
+
+    def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
+        """Track annotated assignments with Pydantic model type hints."""
+        if isinstance(node.target, cst.Name):
+            type_name = self._get_annotation_name(node.annotation.annotation)
+            if type_name in self._pydantic_model_classes:
+                self._pydantic_instance_vars.add(node.target.value)
+        return True
+
+    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
+        """Track function parameters with Pydantic model type annotations."""
+        for param in node.params.params:
+            if param.annotation is not None:
+                type_name = self._get_annotation_name(param.annotation.annotation)
+                if type_name in self._pydantic_model_classes:
+                    self._pydantic_param_vars.add(param.name.value)
         return True

+    def leave_FunctionDef_params(self, node: cst.FunctionDef) -> None:
+        """Clear function-scoped parameter tracking when leaving function."""
+        # Note: This is a simplified approach - ideally we'd use proper scope analysis
+        pass
+
+    def _get_call_func_name(self, node: cst.BaseExpression) -> str:
+        """Get the function/class name from a Call's func attribute."""
+        if isinstance(node, cst.Name):
+            return node.value
+        if isinstance(node, cst.Attribute):
+            return node.attr.value  # Return just the class name part
+        return ""
+
+    def _get_annotation_name(self, node: cst.BaseExpression) -> str:
+        """Extract the type name from a type annotation."""
+        if isinstance(node, cst.Name):
+            return node.value
+        if isinstance(node, cst.Attribute):
+            return node.attr.value  # Return just the class name part
+        if isinstance(node, cst.Subscript):
+            # Handle Optional[Model], List[Model], etc.
+            return self._get_annotation_name(node.value)
+        return ""
+
+    def _is_pydantic_instance(self, node: cst.BaseExpression) -> bool:
+        """Check if an expression is known to be a Pydantic model instance.
+
+        Returns True if we can confirm it's a Pydantic instance.
+        Returns False if we cannot confirm (either unknown or definitely not Pydantic).
+        """
+        if isinstance(node, cst.Name):
+            var_name = node.value
+            # Check if it's a known Pydantic instance variable
+            if var_name in self._pydantic_instance_vars:
+                return True
+            # Check if it's a function parameter with Pydantic type hint
+            if var_name in self._pydantic_param_vars:
+                return True
+            # Heuristic: variable name matches a model class name (case-insensitive)
+            for model_class in self._pydantic_model_classes:
+                if var_name.lower() == model_class.lower():
+                    return True
+            return False
+        if isinstance(node, cst.Call):
+            # Direct call like Model().json() - check if the function is a Pydantic class
+            func_name = self._get_call_func_name(node.func)
+            return func_name in self._pydantic_model_classes
+        if isinstance(node, cst.Attribute):
+            # Could be accessing an attribute that returns a Pydantic model
+            # This is harder to determine without full type analysis
+            return False
+        return False
+
+    def _is_class_method_call(self, node: cst.BaseExpression) -> bool:
+        """Check if this is a call on a class rather than an instance (e.g., Model.parse_obj).
+
+        Class methods like parse_obj, schema, etc. are called on the class itself.
+        """
+        if isinstance(node, cst.Name):
+            # Check if the name is a known Pydantic model class
+            return node.value in self._pydantic_model_classes
+        return False
+
     def leave_ClassDef(
         self, original_node: cst.ClassDef, updated_node: cst.ClassDef
     ) -> cst.ClassDef:
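As a rough sketch of what this new tracking pass records, consider a hypothetical input module (names invented for illustration):

    from pydantic import BaseModel

    class User(BaseModel):         # recorded in _pydantic_model_classes
        name: str

    user = User(name="Ada")        # "user" recorded in _pydantic_instance_vars

    def greet(u: User) -> str:     # "u" recorded in _pydantic_param_vars
        return u.name

Only receivers that end up in these sets (or that match the class-name heuristic in _is_pydantic_instance) become eligible for the method renames in the hunks below.
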
@@ -372,11 +511,17 @@ class PydanticV1ToV2Transformer(BaseTransformer):
         # Handle method calls on objects
         if isinstance(updated_node.func, cst.Attribute):
             method_name = updated_node.func.attr.value
+            obj = updated_node.func.value

-
+            # Methods that can only be called on instances
+            instance_method_mappings = {
                 "dict": "model_dump",
                 "json": "model_dump_json",
                 "copy": "model_copy",
+            }
+
+            # Methods that are typically called on the class (class methods)
+            class_method_mappings = {
                 "parse_obj": "model_validate",
                 "parse_raw": "model_validate_json",
                 "schema": "model_json_schema",

@@ -384,19 +529,40 @@ class PydanticV1ToV2Transformer(BaseTransformer):
                 "update_forward_refs": "model_rebuild",
             }

-
-
-
+            # Handle instance methods - need to verify the object is a Pydantic instance
+            if method_name in instance_method_mappings:
+                # Only transform if we can confirm this is a Pydantic model instance
+                if self._is_pydantic_instance(obj):
+                    new_method = instance_method_mappings[method_name]
+                    new_attr = updated_node.func.with_changes(attr=cst.Name(new_method))

-
-
-
-
-
-
-
+                    self.record_change(
+                        description=f"Convert .{method_name}() to .{new_method}()",
+                        line_number=1,
+                        original=f".{method_name}()",
+                        replacement=f".{new_method}()",
+                        transform_name=f"{method_name}_to_{new_method}",
+                    )
+
+                    return updated_node.with_changes(func=new_attr)
+                # If we can't confirm it's a Pydantic instance, skip transformation
+                # This prevents false positives like response.json() on requests.Response

-
+            # Handle class methods - verify the object is a Pydantic model class
+            if method_name in class_method_mappings:
+                if self._is_class_method_call(obj):
+                    new_method = class_method_mappings[method_name]
+                    new_attr = updated_node.func.with_changes(attr=cst.Name(new_method))
+
+                    self.record_change(
+                        description=f"Convert .{method_name}() to .{new_method}()",
+                        line_number=1,
+                        original=f".{method_name}()",
+                        replacement=f".{new_method}()",
+                        transform_name=f"{method_name}_to_{new_method}",
+                    )
+
+                    return updated_node.with_changes(func=new_attr)

         # Handle Field(regex=...) -> Field(pattern=...)
         if isinstance(updated_node.func, cst.Name) and updated_node.func.value == "Field":

@@ -461,17 +627,20 @@ class PydanticV1ToV2Transformer(BaseTransformer):
         }

         if attr_name in attr_mappings:
-
+            # Only transform if the object is a known Pydantic model class
+            obj = updated_node.value
+            if self._is_class_method_call(obj):
+                new_attr = attr_mappings[attr_name]

-
-
-
-
-
-
-
+                self.record_change(
+                    description=f"Convert {attr_name} to {new_attr}",
+                    line_number=1,
+                    original=attr_name,
+                    replacement=new_attr,
+                    transform_name=f"{attr_name}_rename",
+                )

-
+                return updated_node.with_changes(attr=cst.Name(new_attr))

         return updated_node

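The practical effect, shown on a hypothetical snippet, is that the renames now fire only where the receiver is a known Pydantic model, which is what the false-positive comment above refers to:

    user = User(name="Ada")
    user.dict()               # rewritten to user.model_dump()
    user.json()               # rewritten to user.model_dump_json()
    User.parse_obj(payload)   # rewritten to User.model_validate(payload)

    response = requests.get(url)
    response.json()           # left unchanged: "response" is not a tracked Pydantic instance
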
codeshift/scanner/code_scanner.py
CHANGED

@@ -6,6 +6,17 @@ from pathlib import Path
 import libcst as cst
 from libcst.metadata import MetadataWrapper, PositionProvider

+# Mapping of package names to their actual import names
+# Some packages have different import names than their package names
+PACKAGE_IMPORT_ALIASES: dict[str, list[str]] = {
+    "attrs": ["attr", "attrs"],  # attrs package can be imported as "attr" or "attrs"
+    "pillow": ["PIL"],  # pillow package is imported as PIL
+    "scikit-learn": ["sklearn"],  # scikit-learn is imported as sklearn
+    "beautifulsoup4": ["bs4"],  # beautifulsoup4 is imported as bs4
+    "pyyaml": ["yaml"],  # pyyaml is imported as yaml
+    "python-dateutil": ["dateutil"],  # python-dateutil is imported as dateutil
+}
+

 @dataclass
 class ImportInfo:

@@ -53,15 +64,24 @@ class ImportVisitor(cst.CSTVisitor):

     def __init__(self, target_library: str):
         self.target_library = target_library
+        # Get all possible import names for this library
+        self.import_names = PACKAGE_IMPORT_ALIASES.get(target_library.lower(), [target_library])
         self.imports: list[ImportInfo] = []
         self._imported_names: set[str] = set()

+    def _matches_target_library(self, module_name: str) -> bool:
+        """Check if a module name matches the target library or its aliases."""
+        for import_name in self.import_names:
+            if module_name == import_name or module_name.startswith(f"{import_name}."):
+                return True
+        return False
+
     def visit_Import(self, node: cst.Import) -> None:
         """Visit import statements like 'import pydantic'."""
         for name in node.names if isinstance(node.names, tuple) else []:
             if isinstance(name, cst.ImportAlias):
                 module_name = self._get_name_value(name.name)
-                if module_name and
+                if module_name and self._matches_target_library(module_name):
                     alias = None
                     if name.asname and isinstance(name.asname, cst.AsName):
                         alias = self._get_name_value(name.asname.name)

@@ -84,7 +104,7 @@ class ImportVisitor(cst.CSTVisitor):
             return

         module_name = self._get_name_value(node.module)
-        if not module_name or not
+        if not module_name or not self._matches_target_library(module_name):
             return

         names = []
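A minimal sketch of how the alias table changes scanning behavior; the source string and driver code here are illustrative, not part of the package:

    import libcst as cst
    from libcst.metadata import MetadataWrapper

    wrapper = MetadataWrapper(cst.parse_module("from PIL import Image\nimport yaml\n"))

    visitor = ImportVisitor("pillow")   # import_names resolves to ["PIL"]
    wrapper.visit(visitor)

    # visitor.imports now records the PIL import; the yaml import is ignored
    # because it matches neither "pillow" nor any of its aliases.
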
codeshift/utils/api_client.py
CHANGED
@@ -4,12 +4,109 @@ This client calls the Codeshift API instead of Anthropic directly,
 ensuring that LLM features are gated behind the subscription model.
 """

+import logging
 from dataclasses import dataclass
+from urllib.parse import urlparse

 import httpx

 from codeshift.cli.commands.auth import get_api_key, get_api_url

+logger = logging.getLogger(__name__)
+
+
+class InsecureURLError(Exception):
+    """Raised when an insecure (non-HTTPS) URL is used for API communication.
+
+    HTTPS is required to protect API keys and sensitive data in transit.
+    Man-in-the-middle attacks could intercept API keys if HTTP is used.
+    """
+
+    def __init__(self, url: str, message: str | None = None):
+        self.url = url
+        default_msg = (
+            f"Insecure URL detected: {url}. "
+            "HTTPS is required for API communication to protect your API key. "
+            "Use HTTPS or set CODESHIFT_ALLOW_INSECURE=true for local development only."
+        )
+        super().__init__(message or default_msg)
+
+
+def validate_api_url(url: str) -> str:
+    """Validate and normalize the API URL.
+
+    Enforces HTTPS for all non-localhost hosts to prevent API key interception.
+
+    Args:
+        url: The API URL to validate
+
+    Returns:
+        The validated and normalized URL
+
+    Raises:
+        InsecureURLError: If the URL uses HTTP for a non-localhost host
+        ValueError: If the URL is malformed
+    """
+    if not url:
+        raise ValueError("API URL cannot be empty")
+
+    # Parse the URL
+    try:
+        parsed = urlparse(url)
+    except Exception as e:
+        raise ValueError(f"Malformed URL: {url}") from e
+
+    if not parsed.scheme:
+        raise ValueError(f"URL must include a scheme (http/https): {url}")
+
+    if not parsed.netloc:
+        raise ValueError(f"URL must include a host: {url}")
+
+    # Define localhost patterns
+    localhost_patterns = (
+        "localhost",
+        "127.0.0.1",
+        "::1",
+        "0.0.0.0",
+    )
+
+    host = parsed.hostname or ""
+    is_localhost = any(
+        host == pattern or host.startswith(f"{pattern}:") for pattern in localhost_patterns
+    )
+
+    # Allow HTTP only for localhost (development)
+    if parsed.scheme == "http":
+        # Check for explicit override (development only)
+        import os
+
+        allow_insecure = os.environ.get("CODESHIFT_ALLOW_INSECURE", "").lower() == "true"
+
+        if is_localhost:
+            logger.warning(
+                "Using HTTP for localhost development. " "This should not be used in production."
+            )
+        elif allow_insecure:
+            logger.warning(
+                "SECURITY WARNING: CODESHIFT_ALLOW_INSECURE is set. "
+                "HTTP is being used for API communication. "
+                "Your API key may be exposed to network interception. "
+                "This should ONLY be used for local testing."
+            )
+        else:
+            raise InsecureURLError(
+                url,
+                f"HTTP is not allowed for non-localhost hosts: {host}. "
+                "Use HTTPS to protect your API key from interception.",
+            )
+
+    # Validate HTTPS URLs
+    if parsed.scheme not in ("http", "https"):
+        raise ValueError(f"URL scheme must be http or https, got: {parsed.scheme}")
+
+    # Remove trailing slash for consistency
+    return url.rstrip("/")
+

 @dataclass
 class APIResponse:
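The hostnames below are placeholders; they just illustrate the three paths through validate_api_url:

    validate_api_url("https://api.example.com/")   # returns "https://api.example.com"
    validate_api_url("http://localhost:8000")      # allowed, logs a development warning
    validate_api_url("http://api.example.com")     # raises InsecureURLError unless
                                                   # CODESHIFT_ALLOW_INSECURE=true is set
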
@@ -30,24 +127,46 @@ class CodeshiftAPIClient:
     - Authentication and authorization
     - Quota checking and billing
     - Server-side Anthropic API calls
+
+    Security features:
+    - HTTPS enforcement for all non-localhost URLs
+    - API key protection via secure headers
+    - SSL verification enabled by default
     """

     def __init__(
         self,
         api_key: str | None = None,
         api_url: str | None = None,
-        timeout: int =
+        timeout: int = 180,
+        verify_ssl: bool = True,
     ):
         """Initialize the API client.

         Args:
             api_key: Codeshift API key. Defaults to stored credentials.
             api_url: API base URL. Defaults to stored URL.
-            timeout: Request timeout in seconds.
+            timeout: Request timeout in seconds (default 180 for LLM calls).
+            verify_ssl: Whether to verify SSL certificates (default True).
+
+        Raises:
+            InsecureURLError: If the URL uses HTTP for a non-localhost host.
         """
         self.api_key = api_key or get_api_key()
-
+
+        # Validate and normalize the API URL
+        raw_url = api_url or get_api_url()
+        self.api_url = validate_api_url(raw_url)
+
         self.timeout = timeout
+        self.verify_ssl = verify_ssl
+
+        # Log SSL verification status
+        if not verify_ssl:
+            logger.warning(
+                "SSL verification is disabled. "
+                "This exposes the connection to man-in-the-middle attacks."
+            )

     @property
     def is_available(self) -> bool:

@@ -76,9 +195,13 @@ class CodeshiftAPIClient:

         return httpx.post(
             f"{self.api_url}{endpoint}",
-            headers={
+            headers={
+                "X-API-Key": self.api_key,
+                "Content-Type": "application/json",
+            },
             json=payload,
             timeout=self.timeout,
+            verify=self.verify_ssl,
         )

     def migrate_code(
@@ -157,6 +280,15 @@ class CodeshiftAPIClient:
                 error="LLM migrations require Pro tier or higher. Run 'codeshift upgrade-plan' to upgrade.",
             )

+        elif response.status_code == 429:
+            # Rate limited
+            retry_after = response.headers.get("Retry-After", "60")
+            return APIResponse(
+                success=False,
+                content=code,
+                error=f"Rate limited. Please wait {retry_after} seconds before retrying.",
+            )
+
         elif response.status_code == 503:
             return APIResponse(
                 success=False,

@@ -233,6 +365,14 @@ class CodeshiftAPIClient:
                 error="This feature requires Pro tier or higher.",
             )

+        elif response.status_code == 429:
+            retry_after = response.headers.get("Retry-After", "60")
+            return APIResponse(
+                success=False,
+                content="",
+                error=f"Rate limited. Please wait {retry_after} seconds before retrying.",
+            )
+
         else:
             return APIResponse(
                 success=False,