"""Replace an EC2 instance (found by IP) with a fresh instance holding a new public IP.

Account credentials and launch defaults are read from a YAML file; EC2 and
CloudWatch are accessed through boto3 clients built from those credentials.
"""

import logging
import math
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, TypedDict

import boto3
import yaml
from botocore.exceptions import BotoCoreError, ClientError

logger = logging.getLogger(__name__)

# Instance states searched when looking an instance up by IP (everything except
# shutting-down / terminated).
_LOOKUP_INSTANCE_STATES = ["pending", "running", "stopping", "stopped"]

# Keys that every entry under ``accounts:`` must provide.
_REQUIRED_ACCOUNT_KEYS = ("name", "region", "access_key_id", "secret_access_key", "ami_id")

# CloudWatch GetMetricStatistics returns at most this many datapoints per call.
_CW_MAX_DATAPOINTS = 1440
# Default NetworkOut aggregation period: 6 hours (120 datapoints for 30 days).
_NETWORK_OUT_PERIOD_SECONDS = 6 * 3600


class ConfigError(Exception):
    """Raised when the accounts config file is missing or malformed."""


class AWSOperationError(Exception):
    """Raised when an AWS API call made by this module fails."""


class InstanceSpec(TypedDict, total=False):
    """Launch parameters captured from an existing instance (or supplied as a fallback)."""

    instance_type: Optional[str]
    instance_name: Optional[str]  # value of the "Name" tag
    root_device: Optional[str]  # e.g. the instance's RootDeviceName
    root_size: Optional[int]  # root volume size in GiB, as reported by describe_volumes
    root_volume_type: Optional[str]
    security_group_ids: List[str]
    security_group_names: List[str]
    subnet_id: Optional[str]
    availability_zone: Optional[str]
    region: Optional[str]


@dataclass
class AccountConfig:
    """Credentials and launch defaults for one AWS account/region."""

    name: str
    region: str
    access_key_id: str
    # repr=False keeps the secret out of logs/tracebacks that print the config.
    secret_access_key: str = field(repr=False)
    ami_id: str
    subnet_id: Optional[str] = None
    security_group_ids: List[str] = field(default_factory=list)
    key_name: Optional[str] = None


def load_account_configs(path: str) -> Dict[str, AccountConfig]:
    """Load account configurations from the YAML file at ``path``.

    The file must contain a top-level ``accounts`` list. Each entry needs
    ``name``, ``region``, ``access_key_id``, ``secret_access_key`` and
    ``ami_id``, and may set ``subnet_id``, ``security_group_ids`` and
    ``key_name``.

    Returns:
        Mapping of account name to its config. If two entries share a name,
        the later entry wins.

    Raises:
        ConfigError: if the file does not exist or its structure is invalid.
    """
    if not os.path.exists(path):
        raise ConfigError(f"Config file not found at {path}")
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if not isinstance(data, dict) or "accounts" not in data:
        raise ConfigError("accounts.yaml missing 'accounts' list")

    entries = data["accounts"]
    # `accounts:` with no value loads as None; anything but a list is unusable.
    if not isinstance(entries, list):
        raise ConfigError("accounts.yaml missing 'accounts' list")

    accounts: Dict[str, AccountConfig] = {}
    for index, item in enumerate(entries):
        if not isinstance(item, dict):
            raise ConfigError(f"accounts[{index}] must be a mapping, got {type(item).__name__}")
        missing = [key for key in _REQUIRED_ACCOUNT_KEYS if key not in item]
        if missing:
            raise ConfigError(f"accounts[{index}] missing required keys: {', '.join(missing)}")
        cfg = AccountConfig(
            name=item["name"],
            region=item["region"],
            access_key_id=item["access_key_id"],
            secret_access_key=item["secret_access_key"],
            ami_id=item["ami_id"],
            subnet_id=item.get("subnet_id"),
            # An empty `security_group_ids:` key loads as None; normalize to a list.
            security_group_ids=item.get("security_group_ids") or [],
            key_name=item.get("key_name"),
        )
        accounts[cfg.name] = cfg
    return accounts


def _aws_client(service: str, account: AccountConfig):
    """Build a boto3 client for ``service`` using the account's region and static keys."""
    return boto3.client(
        service,
        region_name=account.region,
        aws_access_key_id=account.access_key_id,
        aws_secret_access_key=account.secret_access_key,
    )


def ec2_client(account: AccountConfig):
    """Return an EC2 client for ``account``."""
    return _aws_client("ec2", account)


def cloudwatch_client(account: AccountConfig):
    """Return a CloudWatch client for ``account``."""
    return _aws_client("cloudwatch", account)


def _get_instance_by_ip(client, ip: str) -> Optional[dict]:
    """Return the first instance whose public IP, or else private IP, equals ``ip``.

    Only instances in the pending/running/stopping/stopped states are
    considered. Returns None when nothing matches.

    Raises:
        AWSOperationError: if describe_instances fails.
    """
    state_filter = {"Name": "instance-state-name", "Values": list(_LOOKUP_INSTANCE_STATES)}
    for filter_name in ("ip-address", "private-ip-address"):
        try:
            resp = client.describe_instances(
                Filters=[state_filter, {"Name": filter_name, "Values": [ip]}]
            )
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed to describe instances: {exc}") from exc
        for reservation in resp.get("Reservations", []):
            for instance in reservation.get("Instances", []):
                return instance
    return None


def _wait_for_state(client, instance_id: str, waiter_name: str) -> None:
    """Block until the boto3 EC2 waiter ``waiter_name`` succeeds for ``instance_id``.

    Waiter failures propagate unchanged (botocore raises ``WaiterError``, a
    ``BotoCoreError`` subclass); callers wrap them.
    """
    client.get_waiter(waiter_name).wait(InstanceIds=[instance_id])


def _get_root_volume_spec(client, instance: dict) -> tuple[Optional[str], Optional[int], Optional[str]]:
    """Return ``(device_name, size_gb, volume_type)`` for the instance's root volume.

    Missing pieces are returned as None (e.g. no RootDeviceName, no EBS
    mapping, or describe_volumes returned no volume).

    Raises:
        AWSOperationError: if describe_volumes fails.
    """
    root_device_name = instance.get("RootDeviceName")
    if not root_device_name:
        return None, None, None

    for mapping in instance.get("BlockDeviceMappings", []):
        if mapping.get("DeviceName") != root_device_name:
            continue
        volume_id = (mapping.get("Ebs") or {}).get("VolumeId")
        if not volume_id:
            return root_device_name, None, None
        try:
            vol_resp = client.describe_volumes(VolumeIds=[volume_id])
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed to read volume info for {volume_id}: {exc}") from exc
        volumes = vol_resp.get("Volumes", [])
        if volumes:
            volume = volumes[0]
            return root_device_name, volume.get("Size"), volume.get("VolumeType")
    return root_device_name, None, None


def _extract_security_group_ids(instance: dict) -> List[str]:
    """Return the non-empty GroupId values of the instance's security groups."""
    return [g["GroupId"] for g in instance.get("SecurityGroups", []) if g.get("GroupId")]


def _extract_security_group_names(instance: dict) -> List[str]:
    """Return the non-empty GroupName values of the instance's security groups."""
    return [g["GroupName"] for g in instance.get("SecurityGroups", []) if g.get("GroupName")]


def _extract_name_tag(instance: dict) -> Optional[str]:
    """Return the value of the instance's "Name" tag, or None if it has none."""
    # `Tags` may be absent or explicitly None.
    tags = instance.get("Tags", []) or []
    return next((tag.get("Value") for tag in tags if tag.get("Key") == "Name"), None)


def _terminate_instance(client, instance_id: str, wait_for_completion: bool = True) -> None:
    """Terminate ``instance_id``, optionally waiting for the terminated state.

    Raises:
        AWSOperationError: if the terminate call or the wait fails.
    """
    try:
        client.terminate_instances(InstanceIds=[instance_id])
        if wait_for_completion:
            _wait_for_state(client, instance_id, "instance_terminated")
    except (ClientError, BotoCoreError) as exc:
        raise AWSOperationError(f"Failed to terminate instance {instance_id}: {exc}") from exc


def _terminate_quietly(client, instance_id: str) -> None:
    """Best-effort terminate used during error cleanup.

    Failures are logged rather than raised so the error that triggered the
    cleanup is the one the caller sees.
    """
    try:
        client.terminate_instances(InstanceIds=[instance_id])
    except (ClientError, BotoCoreError):
        logger.exception("Cleanup failed: could not terminate instance %s", instance_id)


def _build_block_device_mappings(
    device_name: Optional[str], volume_size: Optional[int], volume_type: Optional[str]
) -> Optional[list]:
    """Build a run_instances BlockDeviceMappings list overriding the root volume.

    Returns None when ``device_name`` is unknown, so the AMI defaults apply.
    The volume is always marked DeleteOnTermination; size/type are only set
    when provided.

    NOTE(review): io1/io2 volume types also require ``Iops`` in run_instances,
    which is not captured here — confirm source root volumes are not io1/io2.
    """
    if not device_name:
        return None
    ebs: dict = {"DeleteOnTermination": True}
    if volume_type:
        ebs["VolumeType"] = volume_type
    if volume_size:
        ebs["VolumeSize"] = volume_size
    return [{"DeviceName": device_name, "Ebs": ebs}]


def _provision_instance(
    client,
    account: AccountConfig,
    spec: InstanceSpec,
) -> str:
    """Launch one instance from ``account.ami_id`` using ``spec`` and wait until running.

    If the configured key pair does not exist (``InvalidKeyPair.NotFound``),
    the launch is retried once without a key pair. If the instance is created
    but never reaches the running state, it is terminated (best effort) so it
    is not left orphaned.

    Returns:
        The new instance ID.

    Raises:
        AWSOperationError: if the instance could not be created and started.
    """

    def _build_params(include_key: bool = True) -> dict:
        # NOTE: spec's availability_zone and security_group_names are not sent;
        # placement follows subnet_id (or the account's default subnet).
        params = {
            "ImageId": account.ami_id,
            "InstanceType": spec.get("instance_type"),
            "MinCount": 1,
            "MaxCount": 1,
        }
        if spec.get("instance_name"):
            params["TagSpecifications"] = [
                {
                    "ResourceType": "instance",
                    "Tags": [{"Key": "Name", "Value": spec["instance_name"]}],
                }
            ]
        subnet_id = spec.get("subnet_id")
        if subnet_id:
            params["SubnetId"] = subnet_id
        security_group_ids = spec.get("security_group_ids")
        if security_group_ids:
            params["SecurityGroupIds"] = security_group_ids
        block_mapping = _build_block_device_mappings(
            spec.get("root_device"), spec.get("root_size"), spec.get("root_volume_type")
        )
        if block_mapping:
            params["BlockDeviceMappings"] = block_mapping
        if include_key and account.key_name:
            params["KeyName"] = account.key_name
        return params

    def _run(params: dict) -> str:
        resp = client.run_instances(**params)
        instance_id = resp["Instances"][0]["InstanceId"]
        try:
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError) as exc:
            # The instance exists but is unusable; don't leave it running unattended.
            _terminate_quietly(client, instance_id)
            raise AWSOperationError(f"Failed to create instance: {exc}") from exc
        return instance_id

    try:
        return _run(_build_params())
    except ClientError as exc:
        code = (getattr(exc, "response", None) or {}).get("Error", {}).get("Code")
        if code == "InvalidKeyPair.NotFound" and account.key_name:
            # Fallback: the configured key pair is missing; retry without it.
            try:
                return _run(_build_params(include_key=False))
            except (ClientError, BotoCoreError) as exc2:
                raise AWSOperationError(
                    f"Failed to create instance after removing missing key pair {account.key_name}: {exc2}"
                ) from exc2
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc
    except BotoCoreError as exc:
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc


def _get_public_ip(client, instance_id: str) -> str:
    """Return the instance's current public IP, or "" if it has none.

    Raises:
        AWSOperationError: if the lookup fails or the instance is not found.
    """
    try:
        resp = client.describe_instances(InstanceIds=[instance_id])
    except (ClientError, BotoCoreError) as exc:
        raise AWSOperationError(f"Failed to fetch public IP: {exc}") from exc
    for reservation in resp.get("Reservations", []):
        for instance in reservation.get("Instances", []):
            return instance.get("PublicIpAddress") or ""
    raise AWSOperationError("Instance not found when reading IP")


def _recycle_ip_until_free(client, instance_id: str, banned_ips: set[str], retry_limit: int) -> str:
    """Stop/start ``instance_id`` until its public IP is non-empty and not banned.

    The current IP is checked first; then up to ``retry_limit`` stop/start
    cycles are performed, re-checking the IP after every cycle (including the
    last one).

    Returns:
        The first acceptable public IP.

    Raises:
        AWSOperationError: if a stop/start fails or no acceptable IP was
            obtained within ``retry_limit`` cycles.
    """
    for cycle in range(retry_limit + 1):
        current_ip = _get_public_ip(client, instance_id)
        if current_ip and current_ip not in banned_ips:
            return current_ip
        if cycle == retry_limit:
            # Cycle budget exhausted; the IP from the final cycle was just checked.
            break
        try:
            client.stop_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_stopped")
            client.start_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed while cycling IP: {exc}") from exc
    raise AWSOperationError("Reached retry limit while attempting to obtain a free IP")


def _get_network_out_mb(cw_client, instance_id: str, days: int = 30) -> float:
    """Return the instance's total NetworkOut over the past ``days`` days, in MiB.

    The value is bytes / 1024**2, rounded to 2 decimals; 0.0 when CloudWatch
    has no datapoints.

    Raises:
        AWSOperationError: if the CloudWatch call fails.
    """
    end = datetime.now(timezone.utc)
    start = end - timedelta(days=days)
    # 6-hour granularity by default (covers 30 days in 120 datapoints). For long
    # windows, widen the period in whole hours so the request stays within
    # CloudWatch's per-call datapoint limit.
    window_seconds = (end - start).total_seconds()
    min_period_hours = math.ceil(window_seconds / (_CW_MAX_DATAPOINTS * 3600))
    period = max(_NETWORK_OUT_PERIOD_SECONDS, min_period_hours * 3600)
    try:
        resp = cw_client.get_metric_statistics(
            Namespace="AWS/EC2",
            MetricName="NetworkOut",
            Dimensions=[{"Name": "InstanceId", "Value": instance_id}],
            StartTime=start,
            EndTime=end,
            Period=period,
            Statistics=["Sum"],
        )
    except (ClientError, BotoCoreError) as exc:
        raise AWSOperationError(f"Failed to fetch NetworkOut metrics: {exc}") from exc
    total_bytes = sum(dp.get("Sum", 0.0) for dp in resp.get("Datapoints", []))
    return round(total_bytes / (1024 * 1024), 2)


def _build_spec_from_instance(client, instance: dict, account: AccountConfig) -> InstanceSpec:
    """Capture the launch parameters of an existing instance as an InstanceSpec.

    The subnet falls back to ``account.subnet_id`` when the instance has none.

    Raises:
        AWSOperationError: if the instance type is missing or the root volume
            lookup fails.
    """
    instance_type = instance.get("InstanceType")
    if not instance_type:
        raise AWSOperationError("Failed to detect instance type from source instance")
    root_device, root_size, root_volume_type = _get_root_volume_spec(client, instance)
    return {
        "instance_type": instance_type,
        "instance_name": _extract_name_tag(instance),
        "root_device": root_device,
        "root_size": root_size,
        "root_volume_type": root_volume_type,
        "security_group_ids": _extract_security_group_ids(instance),
        "security_group_names": _extract_security_group_names(instance),
        "subnet_id": instance.get("SubnetId") or account.subnet_id,
        "availability_zone": instance.get("Placement", {}).get("AvailabilityZone"),
        "region": account.region,
    }


def replace_instance_ip(
    ip: str,
    account: AccountConfig,
    disallowed_ips: set[str],
    retry_limit: int = 5,
    fallback_spec: Optional[InstanceSpec] = None,
) -> Dict[str, object]:
    """Replace the instance holding ``ip`` with a new instance that has a different public IP.

    The spec is copied from the live instance found by public/private IP; if
    none is found, ``fallback_spec`` is used. A new instance is launched and
    stop/started (up to ``retry_limit`` times) until its public IP is neither
    in ``disallowed_ips`` nor equal to ``ip``. The old instance (if any) is
    then terminated without waiting for completion.

    If the new instance cannot obtain an acceptable IP, it is terminated
    (best effort) before the error is re-raised, so it is not orphaned.
    ``disallowed_ips`` itself is not modified.

    Returns:
        Dict with ``terminated_instance_id`` (None if no old instance),
        ``new_instance_id``, ``new_ip``, ``spec_used`` and
        ``terminated_network_out_mb`` (None if unavailable).

    Raises:
        AWSOperationError: if no spec is available or any AWS step fails.
    """
    client = ec2_client(account)
    cw = cloudwatch_client(account)

    instance = _get_instance_by_ip(client, ip)
    spec: Optional[InstanceSpec] = None
    instance_id: Optional[str] = None
    network_out_mb: Optional[float] = None

    if instance:
        instance_id = instance["InstanceId"]
        spec = _build_spec_from_instance(client, instance, account)
        try:
            network_out_mb = _get_network_out_mb(cw, instance_id)
        except AWSOperationError:
            # Traffic stats are informational only; don't abort the replacement.
            network_out_mb = None
    elif fallback_spec:
        spec = fallback_spec

    if not spec:
        raise AWSOperationError(f"No instance found with IP {ip} 且数据库无该IP规格信息")

    # Never hand back the IP being replaced (e.g. if the old instance is stopped
    # and has released it). Copy so the caller's set is left untouched.
    banned_ips = set(disallowed_ips)
    banned_ips.add(ip)

    new_instance_id = _provision_instance(client, account, spec)
    try:
        new_ip = _recycle_ip_until_free(client, new_instance_id, banned_ips, retry_limit)
    except AWSOperationError:
        # The caller never learns the new instance ID on failure; clean it up here.
        _terminate_quietly(client, new_instance_id)
        raise

    if instance_id:
        # Don't block on the old instance: terminate it without waiting for completion.
        _terminate_instance(client, instance_id, wait_for_completion=False)

    return {
        "terminated_instance_id": instance_id,
        "new_instance_id": new_instance_id,
        "new_ip": new_ip,
        "spec_used": spec,
        "terminated_network_out_mb": network_out_mb,
    }