from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, TypedDict

import boto3
from botocore.exceptions import BotoCoreError, ClientError


class ConfigError(Exception):
    pass


class AWSOperationError(Exception):
    """Raised when an AWS call made by this module fails or yields an unusable result."""


class InstanceSpec(TypedDict, total=False):
    """Launch parameters for a replacement instance.

    Produced by ``_build_spec_from_instance`` from a live instance, or passed by
    the caller as ``fallback_spec``. ``_provision_instance`` reads
    ``instance_type``, ``instance_name``, ``subnet_id``, ``security_group_ids``
    and the ``root_*`` keys; ``security_group_names``, ``availability_zone`` and
    ``region`` are carried along for reference only.
    """

    instance_type: Optional[str]
    instance_name: Optional[str]
    root_device: Optional[str]
    root_size: Optional[int]
    root_volume_type: Optional[str]
    security_group_ids: List[str]
    security_group_names: List[str]
    subnet_id: Optional[str]
    availability_zone: Optional[str]
    region: Optional[str]


@dataclass
class AccountConfig:
    """Credentials, region and launch defaults for one AWS account."""

    name: str
    region: str
    access_key_id: str
    # repr=False keeps the secret out of logs/tracebacks that print the config.
    secret_access_key: str = field(repr=False)
    ami_id: str
    # Defaults used by _provision_instance when a spec does not provide them.
    subnet_id: Optional[str] = None
    security_group_ids: List[str] = field(default_factory=list)
    # Key pair for new instances; launch is retried without it if AWS reports it missing.
    key_name: Optional[str] = None


def _client(service: str, account: AccountConfig):
    """Create a boto3 client for *service* with the account's region and credentials."""
    return boto3.client(
        service,
        region_name=account.region,
        aws_access_key_id=account.access_key_id,
        aws_secret_access_key=account.secret_access_key,
    )


def ec2_client(account: AccountConfig):
    """Return an EC2 client bound to *account*."""
    return _client("ec2", account)


def cloudwatch_client(account: AccountConfig):
    """Return a CloudWatch client bound to *account*."""
    return _client("cloudwatch", account)


def _get_instance_by_ip(client, ip: str) -> Optional[dict]:
    """Return the first instance whose public or private IPv4 address is *ip*.

    Only instances in state pending/running/stopping/stopped are considered.
    The public address (``ip-address`` filter) is searched first, then the
    private one (``private-ip-address``). Returns ``None`` when nothing matches.

    Raises:
        AWSOperationError: if ``DescribeInstances`` fails.
    """
    state_filter = {
        "Name": "instance-state-name",
        "Values": ["pending", "running", "stopping", "stopped"],
    }
    # Named filter_name (not "field") to avoid shadowing dataclasses.field.
    for filter_name in ("ip-address", "private-ip-address"):
        try:
            resp = client.describe_instances(
                Filters=[state_filter, {"Name": filter_name, "Values": [ip]}]
            )
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed to describe instances: {exc}") from exc
        for reservation in resp.get("Reservations", []):
            for instance in reservation.get("Instances", []):
                return instance
    return None


def _wait_for_state(client, instance_id: str, waiter_name: str) -> None:
    """Block until the boto3 EC2 waiter *waiter_name* succeeds for *instance_id*.

    Waiter failures are not caught here; callers wrap them.
    """
    waiter = client.get_waiter(waiter_name)
    waiter.wait(InstanceIds=[instance_id])


def _get_root_volume_spec(
    client, instance: dict
) -> tuple[Optional[str], Optional[int], Optional[str]]:
    """Return ``(device_name, size_gb, volume_type)`` of the instance's root volume.

    Each element is ``None`` when it cannot be determined (no root device name,
    no EBS mapping for it, or no volume returned by ``DescribeVolumes``).

    Raises:
        AWSOperationError: if ``DescribeVolumes`` fails.
    """
    root_device_name = instance.get("RootDeviceName")
    if not root_device_name:
        return None, None, None
    for mapping in instance.get("BlockDeviceMappings", []):
        if mapping.get("DeviceName") != root_device_name:
            continue
        volume_id = (mapping.get("Ebs") or {}).get("VolumeId")
        if not volume_id:
            return root_device_name, None, None
        try:
            vol_resp = client.describe_volumes(VolumeIds=[volume_id])
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed to read volume info for {volume_id}: {exc}") from exc
        volumes = vol_resp.get("Volumes", [])
        if volumes:
            volume = volumes[0]
            return root_device_name, volume.get("Size"), volume.get("VolumeType")
    return root_device_name, None, None


def _extract_security_group_ids(instance: dict) -> List[str]:
    """Return the non-empty ``GroupId`` values of the instance's security groups."""
    return [g["GroupId"] for g in instance.get("SecurityGroups", []) if g.get("GroupId")]


def _extract_security_group_names(instance: dict) -> List[str]:
    """Return the non-empty ``GroupName`` values of the instance's security groups."""
    return [g["GroupName"] for g in instance.get("SecurityGroups", []) if g.get("GroupName")]


def _extract_name_tag(instance: dict) -> Optional[str]:
    """Return the value of the instance's ``Name`` tag, or ``None`` if absent."""
    for tag in instance.get("Tags", []) or []:
        if tag.get("Key") == "Name":
            return tag.get("Value")
    return None


def _terminate_instance(client, instance_id: str, wait_for_completion: bool = True) -> None:
    """Terminate *instance_id*, optionally waiting for the ``instance_terminated`` waiter.

    Raises:
        AWSOperationError: if the terminate call or the wait fails.
    """
    try:
        client.terminate_instances(InstanceIds=[instance_id])
        if wait_for_completion:
            _wait_for_state(client, instance_id, "instance_terminated")
    except (ClientError, BotoCoreError) as exc:
        raise AWSOperationError(f"Failed to terminate instance {instance_id}: {exc}") from exc


def _build_block_device_mappings(
    device_name: Optional[str], volume_size: Optional[int], volume_type: Optional[str]
) -> Optional[list]:
    """Build a single-entry ``BlockDeviceMappings`` list for ``RunInstances``.

    Returns ``None`` when *device_name* is empty. The EBS volume is always
    ``DeleteOnTermination``; size/type are only set when truthy.
    NOTE(review): this overrides the AMI's root volume only if *device_name*
    matches the AMI's root device name, which is not checked here.
    """
    if not device_name:
        return None
    ebs: dict = {"DeleteOnTermination": True}
    if volume_type:
        ebs["VolumeType"] = volume_type
    if volume_size:
        ebs["VolumeSize"] = volume_size
    return [{"DeviceName": device_name, "Ebs": ebs}]


def _provision_instance(
    client,
    account: AccountConfig,
    spec: InstanceSpec,
) -> str:
    """Launch one instance from ``account.ami_id`` as described by *spec*.

    ``subnet_id`` / ``security_group_ids`` missing from *spec* fall back to the
    account defaults. If AWS reports the account's key pair as missing
    (``InvalidKeyPair.NotFound``), the launch is retried once without a key.
    If an instance is created but never reaches ``running``, termination of it
    is requested (without waiting) so it is not left orphaned.

    Returns:
        The ID of the new, running instance.

    Raises:
        AWSOperationError: if *spec* has no ``instance_type`` or the launch fails.
    """
    instance_type = spec.get("instance_type")
    if not instance_type:
        # Fail early: RunInstances would reject InstanceType=None anyway.
        raise AWSOperationError("Failed to create instance: spec has no instance_type")

    def _build_params(include_key: bool = True) -> dict:
        params: dict = {
            "ImageId": account.ami_id,
            "InstanceType": instance_type,
            "MinCount": 1,
            "MaxCount": 1,
        }
        instance_name = spec.get("instance_name")
        if instance_name:
            params["TagSpecifications"] = [
                {
                    "ResourceType": "instance",
                    "Tags": [{"Key": "Name", "Value": instance_name}],
                }
            ]
        subnet_id = spec.get("subnet_id") or account.subnet_id
        if subnet_id:
            params["SubnetId"] = subnet_id
        security_group_ids = spec.get("security_group_ids") or account.security_group_ids
        if security_group_ids:
            params["SecurityGroupIds"] = list(security_group_ids)
        block_mapping = _build_block_device_mappings(
            spec.get("root_device"), spec.get("root_size"), spec.get("root_volume_type")
        )
        if block_mapping:
            params["BlockDeviceMappings"] = block_mapping
        if include_key and account.key_name:
            params["KeyName"] = account.key_name
        return params

    def _run(params: dict) -> str:
        resp = client.run_instances(**params)
        instance_id = resp["Instances"][0]["InstanceId"]
        try:
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError) as exc:
            # The instance exists but the caller would never learn its ID:
            # request termination instead of leaking it.
            try:
                client.terminate_instances(InstanceIds=[instance_id])
                cleanup = "termination requested"
            except (ClientError, BotoCoreError):
                cleanup = "termination failed, clean up manually"
            raise AWSOperationError(
                f"Instance {instance_id} did not reach running state ({cleanup}): {exc}"
            ) from exc
        return instance_id

    try:
        return _run(_build_params())
    except ClientError as exc:
        code = exc.response.get("Error", {}).get("Code")
        if code != "InvalidKeyPair.NotFound" or not account.key_name:
            raise AWSOperationError(f"Failed to create instance: {exc}") from exc
        # The configured key pair no longer exists: retry once without it.
        try:
            return _run(_build_params(include_key=False))
        except (ClientError, BotoCoreError) as exc2:
            raise AWSOperationError(
                f"Failed to create instance after removing missing key pair {account.key_name}: {exc2}"
            ) from exc2
    except BotoCoreError as exc:
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc


def _get_public_ip(client, instance_id: str) -> str:
    """Return the instance's ``PublicIpAddress``, or ``""`` if it has none.

    Raises:
        AWSOperationError: if ``DescribeInstances`` fails or returns no instance.
    """
    try:
        resp = client.describe_instances(InstanceIds=[instance_id])
    except (ClientError, BotoCoreError) as exc:
        raise AWSOperationError(f"Failed to fetch public IP: {exc}") from exc
    for reservation in resp.get("Reservations", []):
        for instance in reservation.get("Instances", []):
            return instance.get("PublicIpAddress") or ""
    raise AWSOperationError("Instance not found when reading IP")


def _recycle_ip_until_free(client, instance_id: str, banned_ips: set[str], retry_limit: int) -> str:
    """Stop/start *instance_id* until its public IP is non-empty and not in *banned_ips*.

    The IP is checked once up front and again after each of at most
    *retry_limit* stop/start cycles, so the address produced by the final
    cycle is still considered. An instance that never receives a public IP
    consumes every cycle.

    Returns:
        The first acceptable public IP.

    Raises:
        AWSOperationError: if a stop/start fails or the retry limit is reached.
    """
    current_ip = _get_public_ip(client, instance_id)
    cycles = 0
    while not current_ip or current_ip in banned_ips:
        if cycles >= retry_limit:
            raise AWSOperationError("Reached retry limit while attempting to obtain a free IP")
        try:
            client.stop_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_stopped")
            client.start_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed while cycling IP: {exc}") from exc
        cycles += 1
        current_ip = _get_public_ip(client, instance_id)
    return current_ip


# GetMetricStatistics returns at most this many datapoints per request.
_CW_MAX_DATAPOINTS = 1440
# Default aggregation period: 6 hours. Periods are kept to whole hours, which
# CloudWatch requires for windows starting more than 63 days ago.
_CW_DEFAULT_PERIOD = 6 * 3600


def _get_network_out_mb(cw_client, instance_id: str, days: int = 30) -> float:
    """Return the instance's total ``AWS/EC2`` ``NetworkOut`` over the last *days* days, in MB.

    MB here means 1024 * 1024 bytes; the result is rounded to 2 decimals and is
    0.0 when CloudWatch returns no datapoints. The period is 6 hours, widened in
    whole hours only when the window would otherwise exceed the per-request
    datapoint limit.

    Raises:
        AWSOperationError: if the CloudWatch call fails.
    """
    end = datetime.now(timezone.utc)
    start = end - timedelta(days=days)
    # Smallest period (rounded up to whole hours) that keeps the datapoint count under the cap.
    min_period = -(-int(days * 86400) // _CW_MAX_DATAPOINTS)
    period = max(_CW_DEFAULT_PERIOD, -(-min_period // 3600) * 3600)
    try:
        resp = cw_client.get_metric_statistics(
            Namespace="AWS/EC2",
            MetricName="NetworkOut",
            Dimensions=[{"Name": "InstanceId", "Value": instance_id}],
            StartTime=start,
            EndTime=end,
            Period=period,
            Statistics=["Sum"],
        )
    except (ClientError, BotoCoreError) as exc:
        raise AWSOperationError(f"Failed to fetch NetworkOut metrics: {exc}") from exc
    datapoints = resp.get("Datapoints", [])
    if not datapoints:
        return 0.0
    total_bytes = sum(dp.get("Sum", 0.0) for dp in datapoints)
    return round(total_bytes / (1024 * 1024), 2)


def _build_spec_from_instance(client, instance: dict, account: AccountConfig) -> InstanceSpec:
    """Capture the launch-relevant settings of an existing instance as an InstanceSpec.

    The subnet falls back to ``account.subnet_id`` when the instance has none.

    Raises:
        AWSOperationError: if the instance type is missing or the root volume lookup fails.
    """
    instance_type = instance.get("InstanceType")
    if not instance_type:
        raise AWSOperationError("Failed to detect instance type from source instance")
    root_device, root_size, root_volume_type = _get_root_volume_spec(client, instance)
    return {
        "instance_type": instance_type,
        "instance_name": _extract_name_tag(instance),
        "root_device": root_device,
        "root_size": root_size,
        "root_volume_type": root_volume_type,
        "security_group_ids": _extract_security_group_ids(instance),
        "security_group_names": _extract_security_group_names(instance),
        "subnet_id": instance.get("SubnetId") or account.subnet_id,
        "availability_zone": instance.get("Placement", {}).get("AvailabilityZone"),
        "region": account.region,
    }


def replace_instance_ip(
    ip: str,
    account: AccountConfig,
    disallowed_ips: set[str],
    retry_limit: int = 5,
    fallback_spec: Optional[InstanceSpec] = None,
) -> Dict[str, object]:
    """Replace the instance reachable at *ip* with a new instance on a different public IP.

    The instance is looked up by public, then private, IP. If found, its launch
    spec is copied and its 30-day NetworkOut is read (best effort); otherwise
    *fallback_spec* is used. A new instance is launched and stop/started (at
    most *retry_limit* cycles) until its public IP is not in *disallowed_ips*
    and is not one of the old instance's addresses. The old instance is then
    terminated without waiting for completion.

    Args:
        ip: Public or private IPv4 address of the instance to replace.
        account: Account credentials and launch defaults.
        disallowed_ips: Public IPs the new instance must not end up with
            (not mutated).
        retry_limit: Maximum number of stop/start cycles to obtain a free IP.
        fallback_spec: Spec used when no instance with *ip* exists.

    Returns:
        Dict with ``terminated_instance_id`` (None if no instance matched),
        ``new_instance_id``, ``new_ip``, ``spec_used`` and
        ``terminated_network_out_mb`` (None when unavailable).

    Raises:
        AWSOperationError: if no spec is available, the launch fails, no free
            IP is obtained (the new instance is then terminated), or the old
            instance cannot be terminated (the message names the new instance).
    """
    client = ec2_client(account)

    instance = _get_instance_by_ip(client, ip)
    spec: Optional[InstanceSpec] = None
    instance_id: Optional[str] = None
    network_out_mb: Optional[float] = None
    # Copy so the caller's set is never mutated; the old address(es) are always
    # banned because the point of the replacement is a *different* IP.
    banned_ips = set(disallowed_ips)
    banned_ips.add(ip)

    if instance:
        instance_id = instance["InstanceId"]
        old_public_ip = instance.get("PublicIpAddress")
        if old_public_ip:
            banned_ips.add(old_public_ip)
        spec = _build_spec_from_instance(client, instance, account)
        try:
            network_out_mb = _get_network_out_mb(cloudwatch_client(account), instance_id)
        except AWSOperationError:
            # Traffic metrics are informational only; never block the replacement on them.
            network_out_mb = None
    elif fallback_spec:
        spec = fallback_spec

    if not spec:
        raise AWSOperationError(f"No instance found with IP {ip} 且数据库无该IP规格信息")

    new_instance_id = _provision_instance(client, account, spec)
    try:
        new_ip = _recycle_ip_until_free(client, new_instance_id, banned_ips, retry_limit)
    except AWSOperationError as exc:
        # Don't leave an unusable instance running that the caller has no ID for.
        try:
            _terminate_instance(client, new_instance_id, wait_for_completion=False)
            cleanup = "termination requested"
        except AWSOperationError:
            cleanup = "termination failed, clean up manually"
        raise AWSOperationError(
            f"New instance {new_instance_id} could not obtain a free IP ({cleanup}): {exc}"
        ) from exc

    if instance_id:
        # Don't block on the old instance: request termination without waiting for completion.
        try:
            _terminate_instance(client, instance_id, wait_for_completion=False)
        except AWSOperationError as exc:
            # Surface the new instance so it is not lost along with the failure.
            raise AWSOperationError(
                f"New instance {new_instance_id} ({new_ip}) is running, but terminating "
                f"old instance {instance_id} failed: {exc}"
            ) from exc

    return {
        "terminated_instance_id": instance_id,
        "new_instance_id": new_instance_id,
        "new_ip": new_ip,
        "spec_used": spec,
        "terminated_network_out_mb": network_out_mb,
    }