aws-mt5/aws_service.py

323 lines
11 KiB
Python
Raw Permalink Normal View History

2026-01-05 11:07:55 +08:00
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, TypedDict
2026-01-04 18:58:20 +08:00
import boto3
from botocore.exceptions import BotoCoreError, ClientError
class ConfigError(Exception):
    """Configuration error type for this module.

    NOTE(review): not raised by any code in this module; presumably raised by
    whatever loads/validates ``AccountConfig`` objects — confirm with callers.
    """
class AWSOperationError(Exception):
    """Raised when an AWS call made by this module fails.

    Functions here wrap ``botocore`` ``ClientError``/``BotoCoreError`` in this type
    (chaining the original via ``raise ... from``), and also raise it directly for
    failed lookups such as a missing instance.
    """
2026-01-05 11:07:55 +08:00
class InstanceSpec(TypedDict, total=False):
    """Shape of a replacement instance; every key is optional (``total=False``).

    Produced from a live instance by ``_build_spec_from_instance`` or supplied by the
    caller as ``fallback_spec``, then consumed by ``_provision_instance``.
    """

    # EC2 instance type copied from the source instance.
    instance_type: Optional[str]
    # Value for the "Name" tag of the new instance.
    instance_name: Optional[str]
    # Root device name plus size (GiB, from DescribeVolumes ``Size``) and volume type.
    root_device: Optional[str]
    root_size: Optional[int]
    root_volume_type: Optional[str]
    # Security group IDs are passed to RunInstances; names are informational only.
    security_group_ids: List[str]
    security_group_names: List[str]
    # Networking/placement; only subnet_id is used when launching.
    subnet_id: Optional[str]
    availability_zone: Optional[str]
    region: Optional[str]
2026-01-04 18:58:20 +08:00
@dataclass
class AccountConfig:
    """Credentials plus launch defaults for one AWS account in one region.

    ``secret_access_key`` is excluded from the generated ``repr()`` so that logging
    or printing a config object does not leak the credential.
    """

    name: str  # account label; NOTE(review): not read by the code in this module
    region: str  # region passed to every boto3 client built from this config
    access_key_id: str
    # repr=False keeps the secret out of the auto-generated __repr__; it is still a
    # required positional field, and equality/hash behaviour is unchanged.
    secret_access_key: str = field(repr=False)
    ami_id: str  # image used for every replacement instance
    # Fallback subnet when the source instance reports no SubnetId.
    subnet_id: Optional[str] = None
    # NOTE(review): not read by the code in this module.
    security_group_ids: List[str] = field(default_factory=list)
    # EC2 key pair name; launches retry without it if AWS reports it missing.
    key_name: Optional[str] = None
def _aws_client(service_name: str, account: AccountConfig):
    """Build a boto3 client for ``service_name`` from the account's region and static keys."""
    connection_kwargs = {
        "region_name": account.region,
        "aws_access_key_id": account.access_key_id,
        "aws_secret_access_key": account.secret_access_key,
    }
    return boto3.client(service_name, **connection_kwargs)


def ec2_client(account: AccountConfig):
    """Return an EC2 client bound to ``account``'s region and credentials."""
    return _aws_client("ec2", account)


def cloudwatch_client(account: AccountConfig):
    """Return a CloudWatch client bound to ``account``'s region and credentials."""
    return _aws_client("cloudwatch", account)
def _get_instance_by_ip(client, ip: str) -> Optional[dict]:
2026-01-04 18:58:20 +08:00
filters = [
{"Name": "instance-state-name", "Values": ["pending", "running", "stopping", "stopped"]},
]
for field in ["ip-address", "private-ip-address"]:
try:
resp = client.describe_instances(Filters=filters + [{"Name": field, "Values": [ip]}])
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to describe instances: {exc}") from exc
for reservation in resp.get("Reservations", []):
for instance in reservation.get("Instances", []):
2026-01-05 11:07:55 +08:00
return instance
2026-01-04 18:58:20 +08:00
return None
def _wait_for_state(client, instance_id: str, waiter_name: str) -> None:
waiter = client.get_waiter(waiter_name)
waiter.wait(InstanceIds=[instance_id])
2026-01-05 11:07:55 +08:00
def _get_root_volume_spec(client, instance: dict) -> tuple[Optional[str], Optional[int], Optional[str]]:
"""Return (device_name, size_gb, volume_type) for root volume if available."""
root_device_name = instance.get("RootDeviceName")
if not root_device_name:
return None, None, None
for mapping in instance.get("BlockDeviceMappings", []):
if mapping.get("DeviceName") != root_device_name:
continue
ebs = mapping.get("Ebs")
if not ebs:
return root_device_name, None, None
volume_id = ebs.get("VolumeId")
if not volume_id:
return root_device_name, None, None
try:
vol_resp = client.describe_volumes(VolumeIds=[volume_id])
volumes = vol_resp.get("Volumes", [])
if volumes:
volume = volumes[0]
return root_device_name, volume.get("Size"), volume.get("VolumeType")
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to read volume info for {volume_id}: {exc}") from exc
return root_device_name, None, None
def _extract_security_group_ids(instance: dict) -> List[str]:
groups = []
for g in instance.get("SecurityGroups", []):
gid = g.get("GroupId")
if gid:
groups.append(gid)
return groups
def _extract_security_group_names(instance: dict) -> List[str]:
groups = []
for g in instance.get("SecurityGroups", []):
name = g.get("GroupName")
if name:
groups.append(name)
return groups
def _extract_name_tag(instance: dict) -> Optional[str]:
for tag in instance.get("Tags", []) or []:
if tag.get("Key") == "Name":
return tag.get("Value")
return None
def _terminate_instance(client, instance_id: str, wait_for_completion: bool = True) -> None:
2026-01-04 18:58:20 +08:00
try:
client.terminate_instances(InstanceIds=[instance_id])
2026-01-05 11:07:55 +08:00
if wait_for_completion:
_wait_for_state(client, instance_id, "instance_terminated")
2026-01-04 18:58:20 +08:00
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to terminate instance {instance_id}: {exc}") from exc
2026-01-05 11:07:55 +08:00
def _build_block_device_mappings(
device_name: Optional[str], volume_size: Optional[int], volume_type: Optional[str]
) -> Optional[list]:
if not device_name:
return None
ebs = {"DeleteOnTermination": True}
if volume_type:
ebs["VolumeType"] = volume_type
if volume_size:
ebs["VolumeSize"] = volume_size
return [{"DeviceName": device_name, "Ebs": ebs}]
def _provision_instance(
    client,
    account: AccountConfig,
    spec: InstanceSpec,
) -> str:
    """Launch one instance from ``account.ami_id`` shaped by ``spec`` and wait until it runs.

    The launch uses the instance type, Name tag, subnet, security group IDs and root
    volume settings from ``spec``, plus ``account.key_name`` if set. If AWS rejects the
    key pair with ``InvalidKeyPair.NotFound``, the launch is retried once without it.
    If an instance is created but never reaches ``running``, it is terminated
    (best-effort, without waiting) so that it is not orphaned.

    Returns:
        The new instance ID.

    Raises:
        AWSOperationError: if the launch, the keyless retry, or the wait for the
            running state fails.
    """

    def _build_params(include_key: bool = True) -> dict:
        # RunInstances arguments for exactly one instance.
        params = {
            "ImageId": account.ami_id,
            "InstanceType": spec.get("instance_type"),
            "MinCount": 1,
            "MaxCount": 1,
        }
        if spec.get("instance_name"):
            params["TagSpecifications"] = [
                {
                    "ResourceType": "instance",
                    "Tags": [{"Key": "Name", "Value": spec["instance_name"]}],
                }
            ]
        subnet_id = spec.get("subnet_id")
        if subnet_id:
            params["SubnetId"] = subnet_id
        security_group_ids = spec.get("security_group_ids")
        if security_group_ids:
            params["SecurityGroupIds"] = security_group_ids
        block_mapping = _build_block_device_mappings(
            spec.get("root_device"), spec.get("root_size"), spec.get("root_volume_type")
        )
        if block_mapping:
            params["BlockDeviceMappings"] = block_mapping
        if include_key and account.key_name:
            params["KeyName"] = account.key_name
        return params

    def _run(params: dict) -> str:
        # Launch, then block until the instance is running.
        resp = client.run_instances(**params)
        instance_id = resp["Instances"][0]["InstanceId"]
        try:
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError):
            # The caller never receives instance_id when we raise, so terminate the
            # half-started instance instead of leaving it orphaned. Best-effort: the
            # waiter error is the one re-raised.
            try:
                _terminate_instance(client, instance_id, wait_for_completion=False)
            except AWSOperationError:
                pass
            raise
        return instance_id

    try:
        return _run(_build_params())
    except ClientError as exc:
        code = exc.response.get("Error", {}).get("Code") if hasattr(exc, "response") else None
        if code == "InvalidKeyPair.NotFound" and account.key_name:
            # Fallback: the configured key pair does not exist; retry without it.
            try:
                return _run(_build_params(include_key=False))
            except (ClientError, BotoCoreError) as exc2:
                # Chain the retry's own failure (the original key-pair error remains
                # reachable as exc2.__context__).
                raise AWSOperationError(
                    f"Failed to create instance after removing missing key pair {account.key_name}: {exc2}"
                ) from exc2
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc
    except BotoCoreError as exc:
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc
def _get_public_ip(client, instance_id: str) -> str:
try:
resp = client.describe_instances(InstanceIds=[instance_id])
reservations = resp.get("Reservations", [])
if not reservations:
raise AWSOperationError("Instance not found when reading IP")
instance = reservations[0]["Instances"][0]
return instance.get("PublicIpAddress") or ""
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to fetch public IP: {exc}") from exc
def _recycle_ip_until_free(client, instance_id: str, banned_ips: set[str], retry_limit: int) -> str:
    """Stop/start ``instance_id`` until its public IP is non-empty and not in ``banned_ips``.

    The current public IP is checked first. If it is empty or banned, the instance is
    stopped and started again (waiting for each state) and the new IP is checked.
    At most ``retry_limit`` stop/start cycles are performed, and the IP produced by
    every cycle, including the last one, is checked before giving up.

    NOTE(review): relies on a stop/start yielding a different auto-assigned public
    IPv4; an Elastic IP would survive the cycle and just consume retries.

    Returns:
        The first acceptable public IP.

    Raises:
        AWSOperationError: if a stop/start call or waiter fails, or no acceptable IP is
            obtained within ``retry_limit`` cycles.
    """
    cycles_done = 0
    while True:
        current_ip = _get_public_ip(client, instance_id)
        if current_ip and current_ip not in banned_ips:
            return current_ip
        # Give up only after inspecting the result of the final cycle; previously
        # the IP from the last stop/start was never checked.
        if cycles_done >= retry_limit:
            raise AWSOperationError("Reached retry limit while attempting to obtain a free IP")
        try:
            client.stop_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_stopped")
            client.start_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed while cycling IP: {exc}") from exc
        cycles_done += 1
2026-01-05 11:07:55 +08:00
def _get_network_out_mb(cw_client, instance_id: str, days: int = 30) -> float:
"""Fetch total NetworkOut over the past window (MB)."""
end = datetime.now(timezone.utc)
start = end - timedelta(days=days)
try:
resp = cw_client.get_metric_statistics(
Namespace="AWS/EC2",
MetricName="NetworkOut",
Dimensions=[{"Name": "InstanceId", "Value": instance_id}],
StartTime=start,
EndTime=end,
Period=3600 * 6, # 6 小时粒度,覆盖 30 天
Statistics=["Sum"],
)
datapoints = resp.get("Datapoints", [])
if not datapoints:
return 0.0
total_bytes = sum(dp.get("Sum", 0.0) for dp in datapoints)
return round(total_bytes / (1024 * 1024), 2)
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to fetch NetworkOut metrics: {exc}") from exc
def _build_spec_from_instance(client, instance: dict, account: AccountConfig) -> InstanceSpec:
    """Derive an ``InstanceSpec`` that mirrors an existing EC2 instance description.

    The subnet falls back to ``account.subnet_id`` when the instance reports none,
    and ``region`` is taken from the account.

    Raises:
        AWSOperationError: if the instance has no ``InstanceType`` or its root volume
            cannot be described.
    """
    detected_type = instance.get("InstanceType")
    if not detected_type:
        raise AWSOperationError("Failed to detect instance type from source instance")

    device_name, size_gb, volume_type = _get_root_volume_spec(client, instance)
    placement = instance.get("Placement", {})
    subnet = instance.get("SubnetId") or account.subnet_id

    spec: InstanceSpec = {
        "instance_type": detected_type,
        "instance_name": _extract_name_tag(instance),
        "root_device": device_name,
        "root_size": size_gb,
        "root_volume_type": volume_type,
        "security_group_ids": _extract_security_group_ids(instance),
        "security_group_names": _extract_security_group_names(instance),
        "subnet_id": subnet,
        "availability_zone": placement.get("AvailabilityZone"),
        "region": account.region,
    }
    return spec
2026-01-04 18:58:20 +08:00
def replace_instance_ip(
    ip: str,
    account: AccountConfig,
    disallowed_ips: set[str],
    retry_limit: int = 5,
    fallback_spec: Optional[InstanceSpec] = None,
) -> Dict[str, object]:
    """Replace the EC2 instance holding ``ip`` with a new instance whose public IP is allowed.

    1. Look up the instance by public/private IP. If found, copy its spec and, on a
       best-effort basis, its NetworkOut total; otherwise use ``fallback_spec``
       (presumably a stored record; the error message mentions a database).
    2. Launch a new instance from that spec and stop/start it, bounded by
       ``retry_limit``, until its public IP is not in ``disallowed_ips``. If that
       fails, the new instance is terminated (best-effort) before the error propagates.
    3. Terminate the old instance, if one was found, without waiting.

    Returns:
        Dict with ``terminated_instance_id`` (``None`` if no instance matched ``ip``),
        ``new_instance_id``, ``new_ip``, ``spec_used`` and
        ``terminated_network_out_mb`` (``None`` if unavailable).

    Raises:
        AWSOperationError: if no spec can be determined, or any AWS step fails.
    """
    client = ec2_client(account)
    cw = cloudwatch_client(account)

    instance = _get_instance_by_ip(client, ip)
    spec: Optional[InstanceSpec] = None
    instance_id: Optional[str] = None
    network_out_mb: Optional[float] = None
    if instance:
        instance_id = instance["InstanceId"]
        spec = _build_spec_from_instance(client, instance, account)
        try:
            network_out_mb = _get_network_out_mb(cw, instance_id)
        except AWSOperationError:
            # Traffic stats are informational; a metrics failure must not block replacement.
            network_out_mb = None
    elif fallback_spec:
        spec = fallback_spec

    if not spec:
        # Message suffix (Chinese) means: "and the database has no spec info for this IP".
        raise AWSOperationError(f"No instance found with IP {ip} 且数据库无该IP规格信息")

    new_instance_id = _provision_instance(client, account, spec)
    try:
        new_ip = _recycle_ip_until_free(client, new_instance_id, disallowed_ips, retry_limit)
    except AWSOperationError:
        # The caller never learns new_instance_id when this raises, so terminate it
        # rather than leave an orphaned instance running. Cleanup is best-effort:
        # the original error is the one re-raised.
        try:
            _terminate_instance(client, new_instance_id, wait_for_completion=False)
        except AWSOperationError:
            pass
        raise

    if instance_id:
        # Don't block on the old instance: request its termination without waiting
        # for it to finish.
        # NOTE(review): if this raises, the new instance is already running but the
        # caller only sees the error — consider surfacing new_instance_id.
        _terminate_instance(client, instance_id, wait_for_completion=False)

    return {
        "terminated_instance_id": instance_id,
        "new_instance_id": new_instance_id,
        "new_ip": new_ip,
        "spec_used": spec,
        "terminated_network_out_mb": network_out_mb,
    }