AWS-Panel/backend/routers/instances.py

517 lines
20 KiB
Python
Raw Permalink Normal View History

2025-12-10 12:02:17 +08:00
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
from typing import Dict, List, Optional
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import and_, func, or_, select
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from .. import aws_ops
from ..db import SessionLocal, get_session
from ..dependencies import AuthUser, get_current_user
from ..models import (
AWSCredential,
AuditAction,
AuditResourceType,
CustomerCredential,
Instance,
InstanceDesiredStatus,
InstanceStatus,
Job,
JobItem,
JobItemAction,
JobItemResourceType,
JobItemStatus,
JobStatus,
JobType,
)
from ..schemas import (
InstanceCreateRequest,
InstanceFilterParams,
InstanceListResponse,
InstanceOut,
InstanceSyncRequest,
JobOut,
)
from ..utils.audit import create_audit_log
# Router for every endpoint in this module; all routes live under /api/v1/instances.
router = APIRouter(prefix="/api/v1/instances", tags=["instances"])

# Lowercase state-name string -> InstanceStatus. Consulted by _map_state, which
# falls back to InstanceStatus.UNKNOWN for anything not listed here.
STATE_MAP: Dict[str, InstanceStatus] = {
    "pending": InstanceStatus.PENDING,
    "running": InstanceStatus.RUNNING,
    "stopping": InstanceStatus.STOPPING,
    "stopped": InstanceStatus.STOPPED,
    "shutting-down": InstanceStatus.SHUTTING_DOWN,
    "terminated": InstanceStatus.TERMINATED,
}

# Job type recorded on the Job row created for each lifecycle action.
JOB_TYPE_BY_ACTION = {
    JobItemAction.START: JobType.START_INSTANCES,
    JobItemAction.STOP: JobType.STOP_INSTANCES,
    JobItemAction.REBOOT: JobType.REBOOT_INSTANCES,
    JobItemAction.TERMINATE: JobType.TERMINATE_INSTANCES,
}

# Desired status written onto the Instance row when an action is enqueued.
# REBOOT maps to RUNNING, the same target as START.
DESIRED_BY_ACTION = {
    JobItemAction.START: InstanceDesiredStatus.RUNNING,
    JobItemAction.STOP: InstanceDesiredStatus.STOPPED,
    JobItemAction.REBOOT: InstanceDesiredStatus.RUNNING,
    JobItemAction.TERMINATE: InstanceDesiredStatus.TERMINATED,
}
def _map_state(state: Optional[str]) -> InstanceStatus:
    """Translate a raw state string into an ``InstanceStatus``.

    The lookup is case-insensitive. A missing, empty, or unrecognised state
    yields ``InstanceStatus.UNKNOWN``.
    """
    return STATE_MAP.get(state.lower(), InstanceStatus.UNKNOWN) if state else InstanceStatus.UNKNOWN
def _extract_name_tag(tags: Optional[List[Dict[str, str]]]) -> Optional[str]:
if not tags:
return None
for tag in tags:
if tag.get("Key") == "Name":
return tag.get("Value")
return None
async def _ensure_credential_access(
    credential_id: int, auth_user: AuthUser, session: AsyncSession
) -> AWSCredential:
    """Load a credential and verify the caller is allowed to use it.

    Users whose role is ``ADMIN`` may use any active credential. Everyone else
    needs a ``CustomerCredential`` row linking their customer to the
    credential with ``is_allowed == 1``.

    Raises:
        HTTPException: 404 when the credential does not exist or is inactive;
            403 when a non-admin caller has no allowing mapping.
    """
    credential = await session.get(AWSCredential, credential_id)
    if not (credential and credential.is_active):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found or disabled")

    if auth_user.role_name != "ADMIN":
        allowed_link_query = select(CustomerCredential).where(
            CustomerCredential.customer_id == auth_user.customer_id,
            CustomerCredential.credential_id == credential_id,
            CustomerCredential.is_allowed == 1,
        )
        allowed_link = await session.scalar(allowed_link_query)
        if not allowed_link:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Credential not allowed")

    return credential
@router.get("", response_model=InstanceListResponse)
async def list_instances(
    filters: InstanceFilterParams = Depends(),
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> InstanceListResponse:
    """Return one page of instances plus the total number of matches.

    Non-admin callers are always restricted to their own customer; admins may
    optionally narrow by ``filters.customer_id``. The remaining filters are
    exact matches, except ``keyword``, which is a case-insensitive substring
    match over name tag, instance id, public IP and private IP. Results are
    ordered by ``updated_at`` descending.
    """
    stmt = select(Instance)

    # Customer scoping: forced for non-admins, optional for admins.
    if auth_user.role_name == "ADMIN":
        if filters.customer_id:
            stmt = stmt.where(Instance.customer_id == filters.customer_id)
    else:
        stmt = stmt.where(Instance.customer_id == auth_user.customer_id)

    # Exact-match filters, applied only when a value was supplied.
    exact_matches = (
        (filters.credential_id, Instance.credential_id),
        (filters.account_id, Instance.account_id),
        (filters.region, Instance.region),
        (filters.status, Instance.status),
    )
    for wanted, column in exact_matches:
        if wanted:
            stmt = stmt.where(column == wanted)

    if filters.keyword:
        pattern = f"%{filters.keyword}%"
        searchable = (
            Instance.name_tag,
            Instance.instance_id,
            Instance.public_ip,
            Instance.private_ip,
        )
        stmt = stmt.where(or_(*[column.ilike(pattern) for column in searchable]))

    # Count over the filtered query before pagination is applied.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = await session.scalar(count_stmt)

    page_stmt = stmt.order_by(Instance.updated_at.desc()).offset(filters.offset).limit(filters.limit)
    rows = (await session.scalars(page_stmt)).all()

    items = [InstanceOut.model_validate(row) for row in rows]
    return InstanceListResponse(items=items, total=total or 0)
@router.post("/create", response_model=InstanceOut, status_code=status.HTTP_201_CREATED)
async def create_instance(
    payload: InstanceCreateRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> InstanceOut:
    """Launch one instance through ``aws_ops.run_instances`` and persist it.

    The instance is attributed to ``payload.customer_id`` when given, otherwise
    to the caller's own customer. Only ``ADMIN`` users may name a customer
    other than their own. After the launch, an ``Instance`` row is stored from
    the first entry of the AWS response and an audit log entry is written.

    Raises:
        HTTPException: 403 when a non-admin targets another customer or the
            credential is not allowed; 404 when the credential is missing or
            disabled; 400 when no customer can be determined or the
            credential's account differs from ``payload.account_id``; 500 when
            the AWS call fails or returns no instances.
    """
    # Without this check a non-admin could file instances under any customer
    # simply by sending that customer's id in the payload.
    if (
        auth_user.role_name != "ADMIN"
        and payload.customer_id
        and payload.customer_id != auth_user.customer_id
    ):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    customer_id = payload.customer_id or auth_user.customer_id
    if not customer_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required")

    cred = await _ensure_credential_access(payload.credential_id, auth_user, session)
    if cred.account_id != payload.account_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="account_id mismatch")

    region = payload.region or cred.default_region
    try:
        # The AWS helper is called in a worker thread so the event loop is not blocked.
        resp = await asyncio.to_thread(
            aws_ops.run_instances,
            cred,
            region,
            payload.ami_id,
            payload.instance_type,
            payload.key_name,
            payload.security_groups,
            payload.subnet_id,
            1,  # presumably min count - confirm against aws_ops.run_instances
            1,  # presumably max count - confirm against aws_ops.run_instances
            payload.name_tag,
        )
    except Exception as exc:  # pragma: no cover - AWS failure path
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc

    instances = resp.get("Instances") or []
    if not instances:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="AWS did not return instances")

    data = instances[0]
    instance_id = data.get("InstanceId")
    # The requested name wins; fall back to whatever Name tag AWS echoed back.
    name_tag = payload.name_tag or _extract_name_tag(data.get("Tags"))
    security_groups = [sg.get("GroupId") for sg in data.get("SecurityGroups", []) if sg.get("GroupId")]

    db_instance = Instance(
        customer_id=customer_id,
        credential_id=cred.id,
        account_id=cred.account_id,
        region=region,
        az=(data.get("Placement") or {}).get("AvailabilityZone"),
        instance_id=instance_id,
        name_tag=name_tag,
        instance_type=payload.instance_type,
        ami_id=payload.ami_id,
        key_name=payload.key_name,
        public_ip=data.get("PublicIpAddress"),
        private_ip=data.get("PrivateIpAddress"),
        status=_map_state((data.get("State") or {}).get("Name")),
        desired_status=InstanceDesiredStatus.RUNNING,
        security_groups=security_groups,
        subnet_id=data.get("SubnetId"),
        vpc_id=data.get("VpcId"),
        launched_at=data.get("LaunchTime"),
        last_sync=datetime.now(timezone.utc),
        last_cloud_state={"state": data.get("State"), "tags": data.get("Tags")},
    )
    session.add(db_instance)
    # Commit first so the row has a primary key for the audit entry below.
    await session.commit()
    await session.refresh(db_instance)

    await create_audit_log(
        session,
        user_id=auth_user.user.id,
        customer_id=customer_id,
        action=AuditAction.INSTANCE_CREATE,
        resource_type=AuditResourceType.INSTANCE,
        resource_id=db_instance.id,
        description=f"Create instance {db_instance.instance_id}",
        payload=payload.model_dump(),
        request=request,
    )
    await session.commit()
    return InstanceOut.model_validate(db_instance)
# Strong references to in-flight background tasks. The event loop only keeps
# weak references to tasks, so a task whose handle is dropped can be garbage
# collected before it finishes.
_BACKGROUND_TASKS: set = set()


async def _enqueue_instance_action(
    instance: Instance, action: JobItemAction, auth_user: AuthUser, session: AsyncSession
) -> Job:
    """Persist a single-item job for ``action`` and start processing it.

    Creates a PENDING ``Job`` with one PENDING ``JobItem`` pointing at
    ``instance``, sets the instance's ``desired_status`` from
    ``DESIRED_BY_ACTION``, commits, and then schedules
    ``_process_instance_action`` as a background task on the running loop.

    Returns:
        The committed and refreshed ``Job``.

    Raises:
        KeyError: if ``action`` has no entry in ``JOB_TYPE_BY_ACTION`` or
            ``DESIRED_BY_ACTION``.
    """
    job = Job(
        job_uuid=uuid4().hex,
        job_type=JOB_TYPE_BY_ACTION[action],
        status=JobStatus.PENDING,
        progress=0,
        total_count=1,
        created_by_user_id=auth_user.user.id,
        created_for_customer=instance.customer_id,
    )
    session.add(job)
    # Flush so job.id is assigned before the item references it.
    await session.flush()

    job_item = JobItem(
        job_id=job.id,
        resource_type=JobItemResourceType.INSTANCE,
        resource_id=instance.id,
        account_id=instance.account_id,
        region=instance.region,
        instance_id=instance.instance_id,
        action=action,
        status=JobItemStatus.PENDING,
    )
    instance.desired_status = DESIRED_BY_ACTION[action]
    session.add(job_item)
    await session.commit()
    await session.refresh(job)
    await session.refresh(job_item)

    # The worker opens its own session, so it is started only after the rows
    # above are committed and visible to it.
    task = asyncio.create_task(_process_instance_action(job.id, job_item.id, action))
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return job
async def _process_instance_action(job_id: int, job_item_id: int, action: JobItemAction) -> None:
    """Run one queued instance action against AWS and record the outcome.

    Runs in its own ``SessionLocal`` session. The job and item are moved to
    RUNNING, the matching ``aws_ops`` call is made in a worker thread, and the
    job, item and instance rows are updated to reflect success or failure.

    If the job row itself cannot be found there is nothing to update and the
    function returns. If the job exists but its item, instance or credential
    is missing, the job is marked FAILED rather than being left PENDING.
    """
    async with SessionLocal() as session:
        job = await session.get(Job, job_id)
        job_item = await session.scalar(
            select(JobItem)
            .where(JobItem.id == job_item_id)
            .options(selectinload(JobItem.instance).selectinload(Instance.credential))
        )
        if not job:
            return
        if not job_item or not job_item.instance or not job_item.instance.credential:
            # Record the failure so the job does not stay PENDING forever.
            missing_message = "Job item, instance or credential not found"
            if job_item:
                job_item.status = JobItemStatus.FAILED
                job_item.error_message = missing_message
            job.status = JobStatus.FAILED
            job.error_message = missing_message
            job.fail_count = 1
            job.progress = 100
            job.finished_at = datetime.now(timezone.utc)
            await session.commit()
            return

        job.status = JobStatus.RUNNING
        job.started_at = datetime.now(timezone.utc)
        job_item.status = JobItemStatus.RUNNING
        await session.commit()

        cred = job_item.instance.credential
        region = job_item.region or job_item.instance.region
        instance_id = job_item.instance.instance_id
        try:
            # NOTE(review): the status written below is the action's end state,
            # set as soon as the API call returns - confirm whether the
            # transitional state from the AWS response should be stored instead.
            if action == JobItemAction.START:
                resp = await asyncio.to_thread(aws_ops.start_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.RUNNING
            elif action == JobItemAction.STOP:
                resp = await asyncio.to_thread(aws_ops.stop_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.STOPPED
            elif action == JobItemAction.REBOOT:
                resp = await asyncio.to_thread(aws_ops.reboot_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.RUNNING
            elif action == JobItemAction.TERMINATE:
                resp = await asyncio.to_thread(aws_ops.terminate_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.TERMINATED
                job_item.instance.terminated_at = datetime.now(timezone.utc)
            else:  # pragma: no cover
                resp = {}
            job_item.extra = resp
            job_item.status = JobItemStatus.SUCCESS
            job.success_count = 1
            job.total_count = 1
            job.progress = 100
            job.status = JobStatus.SUCCESS
            job.finished_at = datetime.now(timezone.utc)
            await session.commit()
        except Exception as exc:  # pragma: no cover - AWS failure path
            error_text = str(exc)
            # If the failure came from the commit above, the session cannot be
            # used again until it is rolled back. Rolling back also drops any
            # half-applied success state before the failure is recorded.
            await session.rollback()
            job_item.status = JobItemStatus.FAILED
            job_item.error_message = error_text
            job.status = JobStatus.FAILED
            job.error_message = error_text
            job.fail_count = 1
            job.progress = 100
            job.finished_at = datetime.now(timezone.utc)
            await session.commit()
async def _get_instance_or_404(instance_id: int, session: AsyncSession, auth_user: AuthUser) -> Instance:
    """Fetch an instance by primary key and check the caller may act on it.

    Raises:
        HTTPException: 404 when no such row exists; 403 when a non-admin
            caller belongs to a different customer; 400 when the instance has
            no ``credential_id``.
    """
    found = await session.get(Instance, instance_id)
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Instance not found")

    caller_is_admin = auth_user.role_name == "ADMIN"
    if not caller_is_admin and found.customer_id != auth_user.customer_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    if not found.credential_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Instance missing credential")

    return found
async def _action_endpoint(
    instance_id: int,
    action: JobItemAction,
    request: Request,
    session: AsyncSession,
    auth_user: AuthUser,
) -> JobOut:
    """Shared implementation behind the start/stop/reboot/terminate routes.

    Checks that the caller may use the instance and its credential, enqueues
    the action as a job, writes the matching audit log entry, and returns the
    job.
    """
    instance = await _get_instance_or_404(instance_id, session, auth_user)
    await _ensure_credential_access(instance.credential_id, auth_user, session)
    job = await _enqueue_instance_action(instance, action, auth_user, session)

    # Audit action that corresponds to each lifecycle action.
    audit_action_for = {
        JobItemAction.START: AuditAction.INSTANCE_START,
        JobItemAction.STOP: AuditAction.INSTANCE_STOP,
        JobItemAction.REBOOT: AuditAction.INSTANCE_REBOOT,
        JobItemAction.TERMINATE: AuditAction.INSTANCE_TERMINATE,
    }
    audit_kwargs = dict(
        user_id=auth_user.user.id,
        customer_id=instance.customer_id,
        action=audit_action_for[action],
        resource_type=AuditResourceType.INSTANCE,
        resource_id=instance.id,
        description=f"{action.value} instance {instance.instance_id}",
        request=request,
    )
    await create_audit_log(session, **audit_kwargs)
    await session.commit()
    return JobOut.model_validate(job)
@router.post("/{instance_id}/start", response_model=JobOut)
async def start_instance(
    instance_id: int,
    request: Request,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> JobOut:
    """Enqueue a START job for the instance and return the job."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.START,
        request=request,
        session=session,
        auth_user=auth_user,
    )
@router.post("/{instance_id}/stop", response_model=JobOut)
async def stop_instance(
    instance_id: int,
    request: Request,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> JobOut:
    """Enqueue a STOP job for the instance and return the job."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.STOP,
        request=request,
        session=session,
        auth_user=auth_user,
    )
@router.post("/{instance_id}/reboot", response_model=JobOut)
async def reboot_instance(
    instance_id: int,
    request: Request,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> JobOut:
    """Enqueue a REBOOT job for the instance and return the job."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.REBOOT,
        request=request,
        session=session,
        auth_user=auth_user,
    )
@router.post("/{instance_id}/terminate", response_model=JobOut)
async def terminate_instance(
    instance_id: int,
    request: Request,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> JobOut:
    """Enqueue a TERMINATE job for the instance and return the job."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.TERMINATE,
        request=request,
        session=session,
        auth_user=auth_user,
    )
def _build_sync_record(
    inst: dict, cred: AWSCredential, region: str, customer_id: int, synced_at: datetime
) -> dict:
    """Turn one described instance into column values for the ``Instance`` upsert."""
    state_name = (inst.get("State") or {}).get("Name")
    return dict(
        customer_id=customer_id,
        credential_id=cred.id,
        account_id=cred.account_id,
        region=region,
        az=(inst.get("Placement") or {}).get("AvailabilityZone"),
        instance_id=inst.get("InstanceId"),
        name_tag=_extract_name_tag(inst.get("Tags")),
        instance_type=inst.get("InstanceType"),
        ami_id=inst.get("ImageId"),
        key_name=inst.get("KeyName"),
        public_ip=inst.get("PublicIpAddress"),
        private_ip=inst.get("PrivateIpAddress"),
        status=_map_state(state_name),
        # NOTE(review): because every key is also used in the ON DUPLICATE KEY
        # UPDATE clause, each sync resets desired_status to NULL and re-stamps
        # terminated_at on existing rows - confirm this is intended.
        desired_status=None,
        security_groups=[sg.get("GroupId") for sg in inst.get("SecurityGroups", []) if sg.get("GroupId")],
        subnet_id=inst.get("SubnetId"),
        vpc_id=inst.get("VpcId"),
        launched_at=inst.get("LaunchTime"),
        terminated_at=synced_at if state_name == "terminated" else None,
        last_sync=synced_at,
        last_cloud_state={"state": inst.get("State"), "tags": inst.get("Tags")},
    )


@router.post("/sync", response_model=JobOut)
async def sync_instances(
    payload: InstanceSyncRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> JobOut:
    """Refresh stored instances from AWS for one customer and return the sync job.

    For every active, allowed credential of the target customer (optionally
    narrowed to ``payload.credential_id``), instances are described in
    ``payload.region`` or the credential's default region and upserted into
    the ``Instance`` table. One SUCCESS ``JobItem`` is recorded per instance.

    A credential whose describe call fails is skipped and its error is kept in
    the job's ``error_message``. The job is FAILED when every credential
    failed, or when an unexpected error aborts the sync; in the latter case
    all uncommitted sync writes are rolled back.

    Raises:
        HTTPException: 403 when a non-admin targets another customer; 400 when
            no customer can be determined; 404 when no credential matches.
    """
    # Without this check a non-admin could trigger a sync with another
    # customer's credentials by sending that customer's id.
    if (
        auth_user.role_name != "ADMIN"
        and payload.customer_id
        and payload.customer_id != auth_user.customer_id
    ):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    target_customer_id = payload.customer_id or auth_user.customer_id
    if not target_customer_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required")

    credentials_query = (
        select(AWSCredential)
        .join(CustomerCredential, CustomerCredential.credential_id == AWSCredential.id)
        .where(CustomerCredential.customer_id == target_customer_id)
        .where(CustomerCredential.is_allowed == 1)
        .where(AWSCredential.is_active == 1)
    )
    if payload.credential_id:
        credentials_query = credentials_query.where(AWSCredential.id == payload.credential_id)
    credentials = (await session.scalars(credentials_query)).all()
    if not credentials:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No credentials to sync")

    # The job row is committed up front so it survives a rollback of the sync itself.
    job = Job(
        job_uuid=uuid4().hex,
        job_type=JobType.SYNC_INSTANCES,
        status=JobStatus.RUNNING,
        progress=0,
        total_count=0,
        created_by_user_id=auth_user.user.id,
        created_for_customer=target_customer_id,
        payload=payload.model_dump(),
        started_at=datetime.now(timezone.utc),
    )
    session.add(job)
    await session.commit()
    await session.refresh(job)

    synced_count = 0
    credential_errors: List[str] = []
    now = datetime.now(timezone.utc)
    try:
        for cred in credentials:
            region = payload.region or cred.default_region
            try:
                resp = await asyncio.to_thread(aws_ops.describe_instances, cred, region)
            except Exception as exc:  # pragma: no cover
                # Best effort: keep going with the other credentials, but
                # remember the failure instead of discarding it.
                credential_errors.append(f"credential {cred.id} ({region}): {exc}")
                continue

            for res in resp.get("Reservations") or []:
                for inst in res.get("Instances", []):
                    instance_id = inst.get("InstanceId")
                    if not instance_id:
                        # Nothing to key the row on; skip instead of storing a NULL id.
                        continue
                    record = _build_sync_record(inst, cred, region, target_customer_id, now)
                    stmt = insert(Instance).values(**record)
                    update_cols = {k: stmt.inserted[k] for k in record if k != "id"}
                    await session.execute(stmt.on_duplicate_key_update(**update_cols))
                    # Re-read the row to obtain its primary key for the job item.
                    db_inst = await session.scalar(
                        select(Instance).where(
                            Instance.account_id == cred.account_id,
                            Instance.region == region,
                            Instance.instance_id == instance_id,
                        )
                    )
                    session.add(
                        JobItem(
                            job_id=job.id,
                            resource_type=JobItemResourceType.INSTANCE,
                            resource_id=db_inst.id if db_inst else None,
                            account_id=cred.account_id,
                            region=region,
                            instance_id=instance_id,
                            action=JobItemAction.SYNC,
                            status=JobItemStatus.SUCCESS,
                        )
                    )
                    synced_count += 1

        job.total_count = synced_count
        job.success_count = synced_count
        if credential_errors:
            job.error_message = "; ".join(credential_errors)
        # Only a sync in which no credential could be queried counts as failed.
        all_failed = len(credential_errors) == len(credentials)
        job.status = JobStatus.FAILED if all_failed else JobStatus.SUCCESS
        job.progress = 100
        job.finished_at = datetime.now(timezone.utc)
        await create_audit_log(
            session,
            user_id=auth_user.user.id,
            customer_id=target_customer_id,
            action=AuditAction.INSTANCE_SYNC,
            resource_type=AuditResourceType.INSTANCE,
            resource_id=None,
            description=f"Sync instances with {synced_count} records",
            payload=payload.model_dump(),
            request=request,
        )
        await session.commit()
    except Exception as exc:  # pragma: no cover - sync failure path
        error_text = str(exc)
        # A session whose flush or commit failed must be rolled back before it
        # can be used again. This also discards the partial upserts.
        await session.rollback()
        job.status = JobStatus.FAILED
        job.error_message = error_text
        job.progress = 100
        job.finished_at = datetime.now(timezone.utc)
        await session.commit()
        # The rollback expired the job's attributes; reload them before serialising.
        await session.refresh(job)
    return JobOut.model_validate(job)