323 lines
11 KiB
Python
323 lines
11 KiB
Python
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Dict, List, Optional, TypedDict
|
|
|
|
import boto3
|
|
from botocore.exceptions import BotoCoreError, ClientError
|
|
|
|
|
|
class ConfigError(Exception):
    """Exception named for configuration problems; it carries no extra state."""
|
|
|
|
|
|
class AWSOperationError(Exception):
    """Raised by this module's helpers when an AWS call fails or returns unusable data."""
|
|
|
|
|
|
class InstanceSpec(TypedDict, total=False):
    """Description of an instance to launch.

    Declared with ``total=False``, so every key may be absent.
    """

    # Instance sizing and the value used for the "Name" tag.
    instance_type: Optional[str]
    instance_name: Optional[str]

    # Root volume: device name, size and volume type.
    root_device: Optional[str]
    root_size: Optional[int]
    root_volume_type: Optional[str]

    # Security groups, recorded both by ID and by name.
    security_group_ids: List[str]
    security_group_names: List[str]

    # Network placement.
    subnet_id: Optional[str]
    availability_zone: Optional[str]
    region: Optional[str]
|
|
|
|
|
|
@dataclass
class AccountConfig:
    """Credentials and launch defaults for one AWS account.

    Field order is part of the constructor signature and must not change.
    """

    # Label for the account.
    name: str
    # Region the boto3 clients are created for.
    region: str
    access_key_id: str
    # Kept out of the generated ``repr`` so the secret does not end up in
    # logs, tracebacks or debugger output. Equality still compares it.
    secret_access_key: str = field(repr=False)
    # Image used as ``ImageId`` when launching instances.
    ami_id: str
    # Optional launch defaults.
    subnet_id: Optional[str] = None
    security_group_ids: List[str] = field(default_factory=list)
    key_name: Optional[str] = None
|
|
|
|
|
|
def ec2_client(account: AccountConfig):
    """Create a boto3 ``ec2`` client for the account's region using its static key pair."""
    credentials = {
        "aws_access_key_id": account.access_key_id,
        "aws_secret_access_key": account.secret_access_key,
    }
    return boto3.client("ec2", region_name=account.region, **credentials)
|
|
|
|
|
|
def cloudwatch_client(account: AccountConfig):
    """Create a boto3 ``cloudwatch`` client for the account's region using its static key pair."""
    credentials = {
        "aws_access_key_id": account.access_key_id,
        "aws_secret_access_key": account.secret_access_key,
    }
    return boto3.client("cloudwatch", region_name=account.region, **credentials)
|
|
|
|
|
|
def _get_instance_by_ip(client, ip: str) -> Optional[dict]:
|
|
filters = [
|
|
{"Name": "instance-state-name", "Values": ["pending", "running", "stopping", "stopped"]},
|
|
]
|
|
for field in ["ip-address", "private-ip-address"]:
|
|
try:
|
|
resp = client.describe_instances(Filters=filters + [{"Name": field, "Values": [ip]}])
|
|
except (ClientError, BotoCoreError) as exc:
|
|
raise AWSOperationError(f"Failed to describe instances: {exc}") from exc
|
|
|
|
for reservation in resp.get("Reservations", []):
|
|
for instance in reservation.get("Instances", []):
|
|
return instance
|
|
return None
|
|
|
|
|
|
def _wait_for_state(client, instance_id: str, waiter_name: str) -> None:
|
|
waiter = client.get_waiter(waiter_name)
|
|
waiter.wait(InstanceIds=[instance_id])
|
|
|
|
|
|
def _get_root_volume_spec(client, instance: dict) -> tuple[Optional[str], Optional[int], Optional[str]]:
|
|
"""Return (device_name, size_gb, volume_type) for root volume if available."""
|
|
root_device_name = instance.get("RootDeviceName")
|
|
if not root_device_name:
|
|
return None, None, None
|
|
|
|
for mapping in instance.get("BlockDeviceMappings", []):
|
|
if mapping.get("DeviceName") != root_device_name:
|
|
continue
|
|
ebs = mapping.get("Ebs")
|
|
if not ebs:
|
|
return root_device_name, None, None
|
|
volume_id = ebs.get("VolumeId")
|
|
if not volume_id:
|
|
return root_device_name, None, None
|
|
try:
|
|
vol_resp = client.describe_volumes(VolumeIds=[volume_id])
|
|
volumes = vol_resp.get("Volumes", [])
|
|
if volumes:
|
|
volume = volumes[0]
|
|
return root_device_name, volume.get("Size"), volume.get("VolumeType")
|
|
except (ClientError, BotoCoreError) as exc:
|
|
raise AWSOperationError(f"Failed to read volume info for {volume_id}: {exc}") from exc
|
|
return root_device_name, None, None
|
|
|
|
|
|
def _extract_security_group_ids(instance: dict) -> List[str]:
|
|
groups = []
|
|
for g in instance.get("SecurityGroups", []):
|
|
gid = g.get("GroupId")
|
|
if gid:
|
|
groups.append(gid)
|
|
return groups
|
|
|
|
|
|
def _extract_security_group_names(instance: dict) -> List[str]:
|
|
groups = []
|
|
for g in instance.get("SecurityGroups", []):
|
|
name = g.get("GroupName")
|
|
if name:
|
|
groups.append(name)
|
|
return groups
|
|
|
|
|
|
def _extract_name_tag(instance: dict) -> Optional[str]:
|
|
for tag in instance.get("Tags", []) or []:
|
|
if tag.get("Key") == "Name":
|
|
return tag.get("Value")
|
|
return None
|
|
|
|
|
|
def _terminate_instance(client, instance_id: str, wait_for_completion: bool = True) -> None:
|
|
try:
|
|
client.terminate_instances(InstanceIds=[instance_id])
|
|
if wait_for_completion:
|
|
_wait_for_state(client, instance_id, "instance_terminated")
|
|
except (ClientError, BotoCoreError) as exc:
|
|
raise AWSOperationError(f"Failed to terminate instance {instance_id}: {exc}") from exc
|
|
|
|
|
|
def _build_block_device_mappings(
|
|
device_name: Optional[str], volume_size: Optional[int], volume_type: Optional[str]
|
|
) -> Optional[list]:
|
|
if not device_name:
|
|
return None
|
|
ebs = {"DeleteOnTermination": True}
|
|
if volume_type:
|
|
ebs["VolumeType"] = volume_type
|
|
if volume_size:
|
|
ebs["VolumeSize"] = volume_size
|
|
return [{"DeviceName": device_name, "Ebs": ebs}]
|
|
|
|
|
|
def _provision_instance(
    client,
    account: AccountConfig,
    spec: InstanceSpec,
) -> str:
    """Launch one instance from ``account.ami_id`` as described by *spec* and wait until it runs.

    If the launch is rejected with ``InvalidKeyPair.NotFound`` while
    ``account.key_name`` is set, it is retried once without a key pair.

    Args:
        client: Object exposing ``run_instances`` and ``get_waiter``.
        account: Supplies the AMI ID and the optional key pair name.
        spec: Instance type, name tag, subnet, security group IDs and root
            volume settings; absent or falsy entries are left out of the
            request, except ``instance_type``, which is always sent.

    Returns:
        The ID of the newly launched instance.

    Raises:
        AWSOperationError: If the launch, the retry or the wait raises a
            boto error.
    """

    def _build_params(include_key: bool = True) -> dict:
        # NOTE(review): "InstanceType" is passed through even when the spec
        # has none (value None); the request is then left to fail in boto.
        params = {
            "ImageId": account.ami_id,
            "InstanceType": spec.get("instance_type"),
            "MinCount": 1,
            "MaxCount": 1,
        }
        if spec.get("instance_name"):
            params["TagSpecifications"] = [
                {
                    "ResourceType": "instance",
                    "Tags": [{"Key": "Name", "Value": spec["instance_name"]}],
                }
            ]
        subnet_id = spec.get("subnet_id")
        if subnet_id:
            params["SubnetId"] = subnet_id
        security_group_ids = spec.get("security_group_ids")
        if security_group_ids:
            params["SecurityGroupIds"] = security_group_ids
        block_mapping = _build_block_device_mappings(
            spec.get("root_device"), spec.get("root_size"), spec.get("root_volume_type")
        )
        if block_mapping:
            params["BlockDeviceMappings"] = block_mapping
        if include_key and account.key_name:
            params["KeyName"] = account.key_name
        return params

    def _run(params: dict) -> str:
        # NOTE(review): if the waiter fails after run_instances succeeded,
        # the new instance ID never reaches the caller, so the instance is
        # not cleaned up here. Confirm whether that is acceptable.
        resp = client.run_instances(**params)
        instance_id = resp["Instances"][0]["InstanceId"]
        _wait_for_state(client, instance_id, "instance_running")
        return instance_id

    try:
        return _run(_build_params())
    except ClientError as exc:
        code = exc.response.get("Error", {}).get("Code") if hasattr(exc, "response") else None
        if code == "InvalidKeyPair.NotFound" and account.key_name:
            # Fallback: retry without the key pair.
            try:
                return _run(_build_params(include_key=False))
            except (ClientError, BotoCoreError) as exc2:
                # Chain from the retry's own failure. The key-pair error
                # stays reachable through exc2.__context__.
                raise AWSOperationError(
                    f"Failed to create instance after removing missing key pair {account.key_name}: {exc2}"
                ) from exc2
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc
    except BotoCoreError as exc:
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc
|
|
|
|
|
|
def _get_public_ip(client, instance_id: str) -> str:
|
|
try:
|
|
resp = client.describe_instances(InstanceIds=[instance_id])
|
|
reservations = resp.get("Reservations", [])
|
|
if not reservations:
|
|
raise AWSOperationError("Instance not found when reading IP")
|
|
instance = reservations[0]["Instances"][0]
|
|
return instance.get("PublicIpAddress") or ""
|
|
except (ClientError, BotoCoreError) as exc:
|
|
raise AWSOperationError(f"Failed to fetch public IP: {exc}") from exc
|
|
|
|
|
|
def _recycle_ip_until_free(client, instance_id: str, banned_ips: set[str], retry_limit: int) -> str:
    """Stop and start the instance until its public IP is outside *banned_ips*.

    The address is checked before the first cycle and again after every
    cycle, including the last one. A stop/start cycle is therefore never
    paid for without its result being examined.

    Args:
        client: Object exposing ``describe_instances``, ``stop_instances``,
            ``start_instances`` and ``get_waiter``.
        instance_id: Instance whose address is recycled.
        banned_ips: Addresses that must not be returned.
        retry_limit: Maximum number of stop/start cycles to perform.

    Returns:
        The first non-empty public IP found that is not banned.

    Raises:
        AWSOperationError: If a stop/start step raises a boto error, or no
            acceptable address was obtained within *retry_limit* cycles.
    """
    cycles_done = 0
    while True:
        current_ip = _get_public_ip(client, instance_id)
        if current_ip and current_ip not in banned_ips:
            return current_ip
        if cycles_done >= retry_limit:
            break
        try:
            client.stop_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_stopped")
            client.start_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed while cycling IP: {exc}") from exc
        cycles_done += 1
    raise AWSOperationError("Reached retry limit while attempting to obtain a free IP")
|
|
|
|
|
|
def _get_network_out_mb(cw_client, instance_id: str, days: int = 30) -> float:
|
|
"""Fetch total NetworkOut over the past window (MB)."""
|
|
end = datetime.now(timezone.utc)
|
|
start = end - timedelta(days=days)
|
|
try:
|
|
resp = cw_client.get_metric_statistics(
|
|
Namespace="AWS/EC2",
|
|
MetricName="NetworkOut",
|
|
Dimensions=[{"Name": "InstanceId", "Value": instance_id}],
|
|
StartTime=start,
|
|
EndTime=end,
|
|
Period=3600 * 6, # 6 小时粒度,覆盖 30 天
|
|
Statistics=["Sum"],
|
|
)
|
|
datapoints = resp.get("Datapoints", [])
|
|
if not datapoints:
|
|
return 0.0
|
|
total_bytes = sum(dp.get("Sum", 0.0) for dp in datapoints)
|
|
return round(total_bytes / (1024 * 1024), 2)
|
|
except (ClientError, BotoCoreError) as exc:
|
|
raise AWSOperationError(f"Failed to fetch NetworkOut metrics: {exc}") from exc
|
|
|
|
|
|
def _build_spec_from_instance(client, instance: dict, account: AccountConfig) -> InstanceSpec:
    """Derive an ``InstanceSpec`` from an existing instance description.

    Args:
        client: EC2 client, used to look up the root volume.
        instance: Instance dict as returned by ``describe_instances``.
        account: Supplies the region and the subnet used when the instance
            reports none.

    Returns:
        A spec with every ``InstanceSpec`` key filled in (values may be
        ``None`` or empty lists).

    Raises:
        AWSOperationError: If the instance has no ``InstanceType`` or the
            root volume lookup fails.
    """
    detected_type = instance.get("InstanceType")
    if not detected_type:
        raise AWSOperationError("Failed to detect instance type from source instance")

    device_name, size_gb, volume_type = _get_root_volume_spec(client, instance)
    placement = instance.get("Placement", {})

    spec: InstanceSpec = {
        "instance_type": detected_type,
        "instance_name": _extract_name_tag(instance),
        "root_device": device_name,
        "root_size": size_gb,
        "root_volume_type": volume_type,
        "security_group_ids": _extract_security_group_ids(instance),
        "security_group_names": _extract_security_group_names(instance),
        # Prefer the instance's own subnet; fall back to the account default.
        "subnet_id": instance.get("SubnetId") or account.subnet_id,
        "availability_zone": placement.get("AvailabilityZone"),
        "region": account.region,
    }
    return spec
|
|
|
|
|
|
def replace_instance_ip(
    ip: str,
    account: AccountConfig,
    disallowed_ips: set[str],
    retry_limit: int = 5,
    fallback_spec: Optional[InstanceSpec] = None,
) -> Dict[str, object]:
    """Launch a replacement for the instance holding *ip* and retire the old one.

    The launch spec is copied from the instance found at *ip*. When no such
    instance exists, *fallback_spec* is used instead. The new instance's
    address is recycled until it is not in *disallowed_ips*, and only then
    is termination of the old instance requested (without waiting).

    Args:
        ip: Public or private address of the instance to replace.
        account: Credentials, region and launch defaults.
        disallowed_ips: Addresses the new instance must not end up with.
        retry_limit: Passed to the IP recycling step.
        fallback_spec: Spec to launch from when no instance matches *ip*.

    Returns:
        A dict with ``terminated_instance_id``, ``new_instance_id``,
        ``new_ip``, ``spec_used`` and ``terminated_network_out_mb``. The
        first and last are ``None`` when no old instance was found, and the
        last is also ``None`` when the metric query failed.

    Raises:
        AWSOperationError: If no spec is available or any AWS step other
            than the traffic query fails.
    """
    client = ec2_client(account)
    cw = cloudwatch_client(account)
    old_instance = _get_instance_by_ip(client, ip)

    old_instance_id: Optional[str] = None
    old_network_out_mb: Optional[float] = None
    if old_instance:
        old_instance_id = old_instance["InstanceId"]
        spec = _build_spec_from_instance(client, old_instance, account)
        try:
            old_network_out_mb = _get_network_out_mb(cw, old_instance_id)
        except AWSOperationError:
            # The traffic figure is best-effort; report it as unknown.
            old_network_out_mb = None
    elif fallback_spec:
        spec = fallback_spec
    else:
        # The non-English tail of the message reads: "and the database has
        # no spec information for this IP".
        raise AWSOperationError(f"No instance found with IP {ip} 且数据库无该IP规格信息")

    new_instance_id = _provision_instance(client, account, spec)

    # NOTE(review): if recycling raises, the new instance is left running
    # and its ID is not reported to the caller. Confirm this is intended.
    new_ip = _recycle_ip_until_free(client, new_instance_id, disallowed_ips, retry_limit)

    if old_instance_id:
        # Do not hold up the new instance: request termination of the old
        # one but do not wait for it to complete.
        _terminate_instance(client, old_instance_id, wait_for_completion=False)

    return {
        "terminated_instance_id": old_instance_id,
        "new_instance_id": new_instance_id,
        "new_ip": new_ip,
        "spec_used": spec,
        "terminated_network_out_mb": old_network_out_mb,
    }
|