2025-12-10 12:02:17 +08:00

285 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from typing import Any, Dict, List, Optional
import boto3
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError
from backend.core.config import settings
from backend.modules.aws_accounts.models import AWSCredential, CredentialType
# Name of the allow-all security group that ensure_open_all_sg_for_vpc looks for
# (alongside the VPC's "default" group) and creates when no usable group exists.
OPEN_ALL_SG_NAME = "panel-open-all"
def _boto_config() -> BotoConfig:
    """Build the botocore client config shared by every AWS call in this module.

    Connect and read timeouts both come from ``settings.aws_timeout``. When
    ``settings.aws_proxy_url`` is set it is used as the proxy for both HTTPS
    and HTTP traffic; otherwise no proxies are configured.
    """
    proxy_url = settings.aws_proxy_url
    proxy_map = {"https": proxy_url, "http": proxy_url} if proxy_url else None
    timeout = settings.aws_timeout
    return BotoConfig(
        connect_timeout=timeout,
        read_timeout=timeout,
        proxies=proxy_map,
    )
def build_session(credential: AWSCredential, region: str):
    """Create a boto3 session for ``credential`` in ``region``.

    ``region`` falls back to ``credential.default_region`` when it is falsy.

    For ``CredentialType.ACCESS_KEY`` credentials, the stored access key pair
    backs the session directly. For any other credential type, the stored key
    pair is used to call STS ``AssumeRole`` on ``credential.role_arn``:
    - the session name is ``"ec2-panel"``;
    - ``ExternalId`` is added when ``credential.external_id`` is set;
    - the temporary credentials returned by STS back the new session.

    Returns:
        A ``(session, config)`` tuple. ``config`` is the botocore config from
        ``_boto_config()``, which callers pass when creating clients.
    """
    boto_cfg = _boto_config()
    target_region = region or credential.default_region
    key_session = boto3.Session(
        aws_access_key_id=credential.access_key_id,
        aws_secret_access_key=credential.secret_access_key,
        region_name=target_region,
    )
    if credential.credential_type == CredentialType.ACCESS_KEY:
        return key_session, boto_cfg

    # Every non-access-key credential is treated as a role to assume via STS.
    sts_client = key_session.client("sts", config=boto_cfg, region_name=target_region)
    assume_args: Dict[str, Any] = {
        "RoleArn": credential.role_arn,
        "RoleSessionName": "ec2-panel",
    }
    if credential.external_id:
        assume_args["ExternalId"] = credential.external_id
    temp_creds = sts_client.assume_role(**assume_args)["Credentials"]
    role_session = boto3.Session(
        aws_access_key_id=temp_creds["AccessKeyId"],
        aws_secret_access_key=temp_creds["SecretAccessKey"],
        aws_session_token=temp_creds["SessionToken"],
        region_name=target_region,
    )
    return role_session, boto_cfg
def describe_instances(
    credential: AWSCredential,
    region: str,
    filters: Optional[List[Dict[str, Any]]] = None,
    instance_ids: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Call EC2 ``DescribeInstances`` in ``region`` and return the raw response.

    ``Filters`` and ``InstanceIds`` are sent only when the corresponding
    argument is non-empty.
    """
    aws_session, client_config = build_session(credential, region)
    ec2 = aws_session.client(
        "ec2", region_name=region or credential.default_region, config=client_config
    )
    candidates = {"Filters": filters, "InstanceIds": instance_ids}
    request: Dict[str, Any] = {key: value for key, value in candidates.items() if value}
    return ec2.describe_instances(**request)
def describe_instance_status(
    credential: AWSCredential, region: str, instance_ids: List[str]
) -> Dict[str, Any]:
    """Call EC2 ``DescribeInstanceStatus`` for ``instance_ids`` in ``region``.

    Always sends ``IncludeAllInstances=True`` and returns the raw response.
    """
    aws_session, client_config = build_session(credential, region)
    target_region = region or credential.default_region
    ec2 = aws_session.client("ec2", region_name=target_region, config=client_config)
    return ec2.describe_instance_status(
        InstanceIds=instance_ids,
        IncludeAllInstances=True,
    )
def run_instances(
    credential: AWSCredential,
    region: str,
    ami_id: str,
    instance_type: str,
    key_name: Optional[str],
    security_groups: Optional[List[str]],
    subnet_id: Optional[str],
    block_device_mappings: Optional[List[Dict[str, Any]]] = None,
    cpu_options: Optional[Dict[str, Any]] = None,
    min_count: int = 1,
    max_count: int = 1,
    name_tag: Optional[str] = None,
    user_data: Optional[str] = None,
) -> Dict[str, Any]:
    """Launch EC2 instances via ``RunInstances`` and return the raw response.

    Optional arguments are only sent when truthy.

    Args:
        credential: Account credential used to build the boto3 session.
        region: Target region; falls back to ``credential.default_region``.
        ami_id: AMI to launch (``ImageId``).
        instance_type: Instance type (``InstanceType``).
        key_name: Key pair name (``KeyName``).
        security_groups: Security group IDs (``SecurityGroupIds``).
        subnet_id: Subnet ID (``SubnetId``).
        block_device_mappings: ``BlockDeviceMappings`` list.
        cpu_options: CPU-related settings, split by key:
            - a ``"CpuCredits"`` key is sent as ``CreditSpecification``, as
              this parameter always did;
            - any other keys (e.g. ``CoreCount``, ``ThreadsPerCore``) are sent
              as ``CpuOptions``.
        min_count: ``MinCount``.
        max_count: ``MaxCount``.
        name_tag: When set, tags both the instance and its volumes with
            ``Name=<name_tag>``.
        user_data: ``UserData`` script.
    """
    session, cfg = build_session(credential, region)
    client = session.client("ec2", region_name=region or credential.default_region, config=cfg)
    params: Dict[str, Any] = {
        "ImageId": ami_id,
        "InstanceType": instance_type,
        "MinCount": min_count,
        "MaxCount": max_count,
    }
    if key_name:
        params["KeyName"] = key_name
    if security_groups:
        params["SecurityGroupIds"] = security_groups
    if subnet_id:
        params["SubnetId"] = subnet_id
    if block_device_mappings:
        params["BlockDeviceMappings"] = block_device_mappings
    if cpu_options:
        # CreditSpecification only accepts "CpuCredits"; core/thread settings
        # must go in CpuOptions or the request fails parameter validation.
        credit_spec = {k: v for k, v in cpu_options.items() if k == "CpuCredits"}
        cpu_spec = {k: v for k, v in cpu_options.items() if k != "CpuCredits"}
        if credit_spec:
            params["CreditSpecification"] = credit_spec
        if cpu_spec:
            params["CpuOptions"] = cpu_spec
    if name_tag:
        params["TagSpecifications"] = [
            {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": name_tag}]},
            {"ResourceType": "volume", "Tags": [{"Key": "Name", "Value": name_tag}]},
        ]
    if user_data:
        params["UserData"] = user_data
    return client.run_instances(**params)
def start_instances(credential: AWSCredential, region: str, instance_ids: List[str]) -> Dict[str, Any]:
    """Start the EC2 instances in ``instance_ids``; returns the raw ``StartInstances`` response."""
    aws_session, client_config = build_session(credential, region)
    target_region = region or credential.default_region
    ec2 = aws_session.client("ec2", region_name=target_region, config=client_config)
    return ec2.start_instances(InstanceIds=instance_ids)
def stop_instances(credential: AWSCredential, region: str, instance_ids: List[str]) -> Dict[str, Any]:
    """Stop the EC2 instances in ``instance_ids``; returns the raw ``StopInstances`` response."""
    aws_session, client_config = build_session(credential, region)
    target_region = region or credential.default_region
    ec2 = aws_session.client("ec2", region_name=target_region, config=client_config)
    return ec2.stop_instances(InstanceIds=instance_ids)
def reboot_instances(credential: AWSCredential, region: str, instance_ids: List[str]) -> Dict[str, Any]:
    """Reboot the EC2 instances in ``instance_ids``; returns the raw ``RebootInstances`` response."""
    aws_session, client_config = build_session(credential, region)
    target_region = region or credential.default_region
    ec2 = aws_session.client("ec2", region_name=target_region, config=client_config)
    return ec2.reboot_instances(InstanceIds=instance_ids)
def terminate_instances(credential: AWSCredential, region: str, instance_ids: List[str]) -> Dict[str, Any]:
    """Terminate the EC2 instances in ``instance_ids``; returns the raw ``TerminateInstances`` response."""
    aws_session, client_config = build_session(credential, region)
    target_region = region or credential.default_region
    ec2 = aws_session.client("ec2", region_name=target_region, config=client_config)
    return ec2.terminate_instances(InstanceIds=instance_ids)
def get_service_quota(credential: AWSCredential, region: str, service_code: str, quota_code: str) -> Dict[str, Any]:
    """Fetch a single Service Quotas value in ``region``.

    Best-effort lookup used to hint at the maximum number of instances that can
    run in a region. Returns the raw ``GetServiceQuota`` response; errors from
    the call are not caught here.
    """
    aws_session, client_config = build_session(credential, region)
    quotas = aws_session.client(
        "service-quotas",
        region_name=region or credential.default_region,
        config=client_config,
    )
    return quotas.get_service_quota(ServiceCode=service_code, QuotaCode=quota_code)
def describe_vpcs(
    credential: AWSCredential,
    region: str,
    filters: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Describe the VPCs in ``region`` and return the raw ``DescribeVpcs`` response.

    Args:
        credential: Account credential used to build the boto3 session.
        region: Target region; falls back to ``credential.default_region``.
        filters: Optional EC2 filter list, sent as ``Filters`` only when
            non-empty. When omitted the call is made without filters, exactly
            as before this parameter existed.
    """
    session, cfg = build_session(credential, region)
    client = session.client("ec2", region_name=region or credential.default_region, config=cfg)
    params: Dict[str, Any] = {}
    if filters:
        params["Filters"] = filters
    return client.describe_vpcs(**params)
def describe_subnets(credential: AWSCredential, region: str, filters: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
    """Describe subnets in ``region``, optionally narrowed by EC2 ``filters``.

    ``Filters`` is sent only when ``filters`` is non-empty. Returns the raw
    ``DescribeSubnets`` response.
    """
    aws_session, client_config = build_session(credential, region)
    ec2 = aws_session.client(
        "ec2", region_name=region or credential.default_region, config=client_config
    )
    request: Dict[str, Any] = {"Filters": filters} if filters else {}
    return ec2.describe_subnets(**request)
def describe_security_groups(
    credential: AWSCredential, region: str, filters: Optional[List[Dict[str, Any]]] = None
) -> Dict[str, Any]:
    """Describe security groups in ``region``, optionally narrowed by EC2 ``filters``.

    ``Filters`` is sent only when ``filters`` is non-empty. Returns the raw
    ``DescribeSecurityGroups`` response.
    """
    aws_session, client_config = build_session(credential, region)
    ec2 = aws_session.client(
        "ec2", region_name=region or credential.default_region, config=client_config
    )
    request: Dict[str, Any] = {"Filters": filters} if filters else {}
    return ec2.describe_security_groups(**request)
def describe_key_pairs(credential: AWSCredential, region: str) -> Dict[str, Any]:
    """List the EC2 key pairs in ``region``; returns the raw ``DescribeKeyPairs`` response."""
    aws_session, client_config = build_session(credential, region)
    target_region = region or credential.default_region
    ec2 = aws_session.client("ec2", region_name=target_region, config=client_config)
    return ec2.describe_key_pairs()
def create_key_pair(credential: AWSCredential, region: str, key_name: str) -> Dict[str, Any]:
    """Create an RSA key pair named ``key_name`` in ``region``, requesting PEM format.

    Returns the raw ``CreateKeyPair`` response. Per the EC2 API, this response
    is where the private key material is delivered, so callers should treat it
    as sensitive.
    """
    aws_session, client_config = build_session(credential, region)
    target_region = region or credential.default_region
    ec2 = aws_session.client("ec2", region_name=target_region, config=client_config)
    return ec2.create_key_pair(KeyName=key_name, KeyType="rsa", KeyFormat="pem")
def describe_regions(credential: AWSCredential) -> Dict[str, Any]:
    """List EC2 regions, querying through ``credential.default_region``.

    Sends ``AllRegions=True`` and returns the raw ``DescribeRegions`` response.
    """
    home_region = credential.default_region
    aws_session, client_config = build_session(credential, home_region)
    ec2 = aws_session.client("ec2", region_name=home_region, config=client_config)
    return ec2.describe_regions(AllRegions=True)
def describe_instance_types(credential: AWSCredential, region: str, filters: Optional[List[Dict[str, Any]]] = None) -> List[Dict[str, Any]]:
    """Return every instance-type description in ``region``, following pagination.

    ``Filters`` is sent only when ``filters`` is non-empty. The ``InstanceTypes``
    entries from all pages are concatenated in page order; a page without that
    key contributes nothing.
    """
    aws_session, client_config = build_session(credential, region)
    ec2 = aws_session.client(
        "ec2", region_name=region or credential.default_region, config=client_config
    )
    request: Dict[str, Any] = {"Filters": filters} if filters else {}
    pages = ec2.get_paginator("describe_instance_types").paginate(**request)
    return [itype for page in pages for itype in page.get("InstanceTypes", [])]
def describe_images(credential: AWSCredential, region: str, image_ids: List[str]) -> Dict[str, Any]:
    """Describe the AMIs listed in ``image_ids`` in ``region``; returns the raw response.

    An empty (or ``None``) ``image_ids`` short-circuits to ``{"Images": []}``
    without calling AWS. An empty list is not sent as an ``ImageIds``
    constraint, and ``DescribeImages`` with no ``ImageIds`` lists every image
    visible to the account, including all public AMIs. That is a huge, slow
    response no caller of this helper wants.
    """
    if not image_ids:
        return {"Images": []}
    session, cfg = build_session(credential, region)
    client = session.client("ec2", region_name=region or credential.default_region, config=cfg)
    return client.describe_images(ImageIds=image_ids)
def _is_open_all_sg(sg: Dict[str, Any]) -> bool:
ingress = sg.get("IpPermissions", [])
egress = sg.get("IpPermissionsEgress", [])
def _has_all(perms: List[Dict[str, Any]]) -> bool:
for p in perms:
if p.get("IpProtocol") == "-1":
if any(r.get("CidrIp") == "0.0.0.0/0" for r in p.get("IpRanges", [])):
return True
return False
return _has_all(ingress) and _has_all(egress)
def _authorize_open_all(ec2_client, group_id: str) -> None:
    """Add allow-all IPv4 ingress and egress rules to security group ``group_id``.

    A rule that already exists (``InvalidPermission.Duplicate``) is ignored;
    any other ``ClientError`` is re-raised.
    """
    for authorize in (
        ec2_client.authorize_security_group_ingress,
        ec2_client.authorize_security_group_egress,
    ):
        try:
            # A fresh rule literal per call so no request shares a mutable dict.
            authorize(
                GroupId=group_id,
                IpPermissions=[{"IpProtocol": "-1", "IpRanges": [{"CidrIp": "0.0.0.0/0"}]}],
            )
        except ClientError as exc:  # noqa: PERF203
            if exc.response.get("Error", {}).get("Code") != "InvalidPermission.Duplicate":
                raise


def ensure_open_all_sg_for_vpc(ec2_client, vpc_id: str) -> str:
    """Ensure an open-all security group exists in a VPC. Returns GroupId.

    Resolution order:
    1. A group named ``"default"`` or ``OPEN_ALL_SG_NAME`` that already passes
       ``_is_open_all_sg`` is returned unchanged.
    2. Otherwise, if the VPC's ``"default"`` group exists, allow-all ingress and
       egress rules are added to it and its id is returned.
    3. Otherwise, an existing (closed) ``OPEN_ALL_SG_NAME`` group is re-opened
       and returned. Creating it again would fail, because EC2 security group
       names are unique within a VPC.
    4. Otherwise, a new ``OPEN_ALL_SG_NAME`` group is created and opened.

    Raises:
        botocore.exceptions.ClientError: from any EC2 call. Duplicate-rule
        errors while authorizing are ignored instead.

    NOTE(review): step 2 opens the VPC's default security group in place, which
    affects every resource already attached to that group.
    """
    resp = ec2_client.describe_security_groups(Filters=[{"Name": "vpc-id", "Values": [vpc_id]}])
    default_sg_id: Optional[str] = None
    panel_sg_id: Optional[str] = None
    for sg in resp.get("SecurityGroups", []):
        name = sg.get("GroupName")
        if name == "default":
            default_sg_id = sg.get("GroupId")
        elif name == OPEN_ALL_SG_NAME:
            panel_sg_id = sg.get("GroupId")
        if name in ("default", OPEN_ALL_SG_NAME) and _is_open_all_sg(sg):
            return sg["GroupId"]
    # The default group exists but is not fully open: open the default group directly.
    if default_sg_id:
        _authorize_open_all(ec2_client, default_sg_id)
        return default_sg_id
    # Reuse a previously created panel group rather than creating a duplicate name.
    if panel_sg_id:
        _authorize_open_all(ec2_client, panel_sg_id)
        return panel_sg_id
    create_resp = ec2_client.create_security_group(
        GroupName=OPEN_ALL_SG_NAME,
        Description="Open all inbound/outbound for panel-created instances",
        VpcId=vpc_id,
    )
    sg_id = create_resp["GroupId"]
    _authorize_open_all(ec2_client, sg_id)
    return sg_id