"""Instance management endpoints mounted under ``/api/v1/instances``.

Covers the real-time instance event websocket, listing/creation/batch
operations, per-instance power actions (audited), quota/capacity lookups,
XLSX exports and AWS metadata helpers (regions, network, key pairs,
instance types).
"""

import asyncio
import logging
import re
from datetime import datetime
from io import BytesIO

from fastapi import APIRouter, Body, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect, status
from fastapi.responses import StreamingResponse
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from backend.api.deps import AuthUser, get_current_user
from backend.core.security import decode_token
from backend.db.session import get_session
from backend.modules.customers.models import Customer
from backend.modules.instances.events import instance_event_manager
from backend.modules.instances.models import Instance
from backend.modules.instances.schemas import (
    InstanceCreateRequest,
    InstanceCreateResponse,
    InstanceFilterParams,
    InstanceListResponse,
    InstanceOut,
    InstanceSyncRequest,
    BatchInstancesActionIn,
    BatchInstancesActionOut,
    BatchInstancesByIpIn,
    BatchInstancesByIpOut,
    InstanceIdsExportIn,
)
from backend.modules.instances.service import (
    enqueue_action,
    ensure_credential_access,
    list_instances,
    create_instance,
    sync_instances,
    batch_instances_action,
    batch_instances_by_ips,
    export_instances,
    export_instances_by_ids,
)
from backend.modules.jobs.models import JobItemAction
from backend.modules.audit.models import AuditAction, AuditLog, AuditResourceType
from backend.modules.instances.constants import AWS_REGIONS
from backend.modules.instances import aws_ops
from backend.modules.users.models import RoleName, User

router = APIRouter(prefix="/api/v1/instances", tags=["instances"])
logger = logging.getLogger(__name__)

# Quota reported when no customer-level limit applies ("effectively unlimited").
_UNLIMITED_QUOTA = 999999
# AWS Service Quotas code queried (best-effort) by ``capacity_by_region``.
_EC2_ON_DEMAND_QUOTA_CODE = "L-1216C47A"
# Instance types non-admin users may see; used both as the AWS-side filter
# and as a defensive local filter on the response.
_CUSTOMER_INSTANCE_TYPES = ("t3.micro", "t3.small", "t3.medium")
_XLSX_MEDIA_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
# Anything outside this set is replaced before being put into a
# Content-Disposition filename (quotes / CR-LF / non-latin-1 break the header).
_FILENAME_UNSAFE_CHARS = re.compile(r"[^A-Za-z0-9._-]")
# Maps a power action to the audit action recorded for it.
_AUDIT_ACTIONS = {
    JobItemAction.START: AuditAction.INSTANCE_START,
    JobItemAction.STOP: AuditAction.INSTANCE_STOP,
    JobItemAction.REBOOT: AuditAction.INSTANCE_REBOOT,
    JobItemAction.TERMINATE: AuditAction.INSTANCE_TERMINATE,
}


async def _auth_websocket(websocket: WebSocket, session: AsyncSession) -> AuthUser | None:
    """Authenticate a websocket from its ``token`` query parameter.

    On failure the socket is closed (4401 for token problems, 4403 for a
    missing or inactive user) and ``None`` is returned. On success an
    ``AuthUser`` is built from the token and the user row, with ``role`` and
    ``customer`` eagerly loaded.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4401, reason="Missing token")
        return None
    try:
        payload = decode_token(token)
    except ValueError:
        await websocket.close(code=4401, reason="Invalid token")
        return None
    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=4401, reason="Invalid token payload")
        return None
    try:
        numeric_user_id = int(user_id)
    except (TypeError, ValueError):
        # A non-numeric "sub" claim must be rejected, not raise out of the handler.
        await websocket.close(code=4401, reason="Invalid token payload")
        return None
    user = await session.scalar(
        select(User)
        .where(User.id == numeric_user_id)
        .options(selectinload(User.role), selectinload(User.customer))
    )
    if not user or not user.is_active:
        await websocket.close(code=4403, reason="User disabled or not found")
        return None
    # The token's role claim takes precedence over the stored role.
    role_name = payload.get("role") or (user.role.name if user.role else RoleName.CUSTOMER_USER.value)
    return AuthUser(user=user, role_name=role_name, customer_id=user.customer_id, token=token)


@router.websocket("/ws")
async def instance_events(
    websocket: WebSocket,
    session: AsyncSession = Depends(get_session),
):
    """Stream instance events to an authenticated websocket client.

    Incoming text frames are read and discarded; they only keep the
    connection loop alive until the client disconnects.
    """
    auth_user = await _auth_websocket(websocket, session)
    # The socket may stay open for a long time while the request-scoped session
    # would otherwise keep the connection checked out by the auth query.
    # close() releases it and detaches (without expiring) the loaded user.
    await session.close()
    if not auth_user:
        return
    await instance_event_manager.connect(websocket, auth_user)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception:
        logger.debug("instance event websocket closed with an error", exc_info=True)
    finally:
        # Also runs on task cancellation, so the manager never keeps a dead socket.
        await instance_event_manager.disconnect(websocket)


@router.get("", response_model=InstanceListResponse)
async def list_instances_endpoint(
    filters: InstanceFilterParams = Depends(),
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> InstanceListResponse:
    """List instances visible to the caller, filtered by the query parameters."""
    items, total = await list_instances(session, filters.model_dump(exclude_none=True), auth_user.user)
    return InstanceListResponse(items=[InstanceOut.model_validate(i) for i in items], total=total)


@router.post("/create", response_model=InstanceCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_instance_endpoint(
    payload: InstanceCreateRequest,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> InstanceCreateResponse:
    """Create an instance and return it together with any generated login credentials."""
    try:
        result = await create_instance(session, payload.model_dump(), auth_user.user)
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover
        logger.exception("create_instance failed")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc))
    return InstanceCreateResponse(
        **InstanceOut.model_validate(result["instance"]).model_dump(),
        login_username=result.get("login_username"),
        login_password=result.get("login_password"),
    )


@router.post("/batch/action", response_model=BatchInstancesActionOut)
async def batch_instances_action_endpoint(
    payload: BatchInstancesActionIn,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> BatchInstancesActionOut:
    """Apply one action to a batch of instances (delegates to the service layer)."""
    return await batch_instances_action(session, payload, auth_user.user)


@router.post("/batch/by-ips", response_model=BatchInstancesByIpOut)
async def batch_instances_by_ips_endpoint(
    payload: BatchInstancesByIpIn,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> BatchInstancesByIpOut:
    """Resolve/act on instances identified by IP address (delegates to the service layer)."""
    result = await batch_instances_by_ips(session, payload, auth_user.user)
    return BatchInstancesByIpOut(**result)


async def _action_endpoint(
    instance_id: int,
    action: JobItemAction,
    session: AsyncSession,
    auth_user: AuthUser,
) -> dict:
    """Enqueue a power action for one instance and record an audit log entry.

    Raises:
        HTTPException: 404 if the instance does not exist, 403 if a non-admin
            caller does not own it, 400 if it has no credential.
    """
    instance = await session.get(Instance, instance_id)
    if not instance:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Instance not found")
    if auth_user.role_name != "ADMIN" and instance.customer_id != auth_user.customer_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
    if not instance.credential_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Instance missing credential")
    await ensure_credential_access(session, instance.credential_id, auth_user.user)
    job = await enqueue_action(session, instance, action, auth_user.user)
    session.add(
        AuditLog(
            user_id=auth_user.user.id,
            customer_id=instance.customer_id,
            action=_AUDIT_ACTIONS[action],
            resource_type=AuditResourceType.INSTANCE,
            resource_id=instance.id,
            description=f"{action.value} instance {instance.instance_id}",
        )
    )
    await session.commit()
    return {"job_uuid": job.job_uuid, "job_id": job.id}


@router.post("/{instance_id}/start")
async def start_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a START job for the instance."""
    return await _action_endpoint(instance_id, JobItemAction.START, session, auth_user)


@router.post("/{instance_id}/stop")
async def stop_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a STOP job for the instance."""
    return await _action_endpoint(instance_id, JobItemAction.STOP, session, auth_user)


@router.post("/{instance_id}/reboot")
async def reboot_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a REBOOT job for the instance."""
    return await _action_endpoint(instance_id, JobItemAction.REBOOT, session, auth_user)


@router.post("/{instance_id}/terminate")
async def terminate_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a TERMINATE job for the instance."""
    return await _action_endpoint(instance_id, JobItemAction.TERMINATE, session, auth_user)


@router.post("/sync")
async def sync_instances_endpoint(
    payload: InstanceSyncRequest,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Start a sync job for a credential/region and return its identifiers."""
    job = await sync_instances(session, payload.credential_id, payload.region, auth_user.user, payload.customer_id)
    return {"job_uuid": job.job_uuid, "job_id": job.id, "total": job.total_count}


def _resolve_target_customer_id(auth_user: AuthUser, customer_id: int | None) -> int | None:
    """Pick the customer a quota query is scoped to.

    Admins may pass ``customer_id`` (falling back to their own customer);
    everyone else is always scoped to their own customer. ``None`` means
    "no customer scope" (global view).
    """
    if auth_user.role_name == "ADMIN":
        return customer_id or auth_user.customer_id
    return auth_user.customer_id


async def _customer_quota(session: AsyncSession, customer_id: int) -> int:
    """Return the customer's instance quota, or the unlimited sentinel when unset.

    Raises:
        HTTPException: 404 if the customer does not exist.
    """
    customer = await session.get(Customer, customer_id)
    if not customer:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Customer not found")
    return customer.quota_instances if customer.quota_instances is not None else _UNLIMITED_QUOTA


async def _count_active_instances(
    session: AsyncSession,
    customer_id: int | None,
    region: str | None = None,
) -> int:
    """Count instances that are neither TERMINATED nor desired-TERMINATED.

    Optionally restricted to one customer and/or one region; ``None`` means
    no restriction on that dimension.
    """
    conditions = []
    if customer_id is not None:
        conditions.append(Instance.customer_id == customer_id)
    conditions.append(Instance.status != "TERMINATED")
    conditions.append(Instance.desired_status != "TERMINATED")
    if region is not None:
        conditions.append(Instance.region == region)
    count = await session.scalar(select(func.count(Instance.id)).where(*conditions))
    return count or 0


@router.get("/quota/available")
async def available_instance_quota(
    customer_id: int | None = None,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Report the instance quota, current usage and remaining headroom.

    Admins can omit ``customer_id`` to get a global view against the
    unlimited sentinel quota.
    """
    target_customer_id = _resolve_target_customer_id(auth_user, customer_id)
    if target_customer_id is None:
        quota = _UNLIMITED_QUOTA  # effectively unlimited when not scoped to a customer
    else:
        quota = await _customer_quota(session, target_customer_id)
    active_count = await _count_active_instances(session, target_customer_id)
    available = max(0, quota - active_count)
    return {"quota": quota, "in_use": active_count, "available": available}


@router.get("/quota/capacity")
async def capacity_by_region(
    credential_id: int,
    region: str,
    customer_id: int | None = None,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Report per-region launch capacity for a credential.

    ``available`` is the customer quota minus active instances in the region,
    further capped (best-effort) by the AWS service quota when it can be read.
    """
    # permission + credential check
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    target_customer_id = _resolve_target_customer_id(auth_user, customer_id)
    if target_customer_id is None:
        # admin without customer scope: treat as global (only constrained by AWS quota)
        quota = _UNLIMITED_QUOTA
    else:
        quota = await _customer_quota(session, target_customer_id)
    active_count = await _count_active_instances(session, target_customer_id, region)
    available = max(0, quota - active_count)

    # Best-effort: respect the AWS regional quota when it can be read.
    # NOTE(review): AWS documents this on-demand quota in vCPUs while
    # active_count is an instance count (and customer-scoped) — confirm intent.
    aws_available = None
    try:
        resp = await asyncio.to_thread(
            aws_ops.get_service_quota, cred, region, "ec2", _EC2_ON_DEMAND_QUOTA_CODE
        )
        if resp and resp.get("Quota", {}).get("Value") is not None:
            quota_val = int(resp["Quota"]["Value"])
            aws_available = max(0, quota_val - active_count)
            available = min(available, aws_available)
    except Exception as exc:
        logger.warning(
            "AWS service quota lookup failed for credential %s in %s: %s", credential_id, region, exc
        )
    return {
        "quota": quota,
        "in_use_region": active_count,
        "available": available,
        "aws_available": aws_available,
        "region": region,
        "credential_id": credential_id,
    }


def _safe_filename_part(value: object) -> str:
    """Make a value safe to embed in a quoted Content-Disposition filename."""
    return _FILENAME_UNSAFE_CHARS.sub("_", str(value))


def _xlsx_response(content: bytes, filename: str) -> StreamingResponse:
    """Wrap XLSX bytes in a download response with the given filename."""
    return StreamingResponse(
        BytesIO(content),
        media_type=_XLSX_MEDIA_TYPE,
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )


@router.get("/export")
async def export_instances_endpoint(
    filters: InstanceFilterParams = Depends(),
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Export all instances matching the filters (pagination ignored) as XLSX."""
    payload = filters.model_dump(exclude_none=True)
    payload.pop("offset", None)
    payload.pop("limit", None)
    content = await export_instances(session, payload, auth_user.user)
    credential = _safe_filename_part(payload.get("credential_id") or "all")
    region = _safe_filename_part(payload.get("region") or "all")
    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    return _xlsx_response(content, f"instances_{credential}_{region}_{stamp}.xlsx")


@router.post("/export-by-ids")
async def export_by_ids_endpoint(
    payload: InstanceIdsExportIn,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Export the selected instances as XLSX."""
    content = await export_instances_by_ids(session, payload.instance_ids, auth_user.user)
    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    return _xlsx_response(content, f"instances_selected_{stamp}.xlsx")


@router.get("/meta/aws/regions")
async def list_regions(
    credential_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """List AWS regions for a credential, falling back to the static AWS_REGIONS table.

    The fallback is used when the AWS call fails or returns no named regions.
    """
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    try:
        resp = await asyncio.to_thread(aws_ops.describe_regions, cred)
        items = []
        for r in resp.get("Regions", []):
            name = r.get("RegionName")
            if not name:
                continue
            meta = AWS_REGIONS.get(name, {"en": name, "zh": ""})
            items.append({"id": name, "label_en": meta["en"], "label_zh": meta["zh"]})
        if items:
            return items
    except Exception as exc:
        logger.warning("describe_regions failed for credential %s; using static list: %s", credential_id, exc)
    return [
        {"id": region, "label_en": meta["en"], "label_zh": meta["zh"]}
        for region, meta in AWS_REGIONS.items()
    ]


def _tag_name(resource: dict) -> str | None:
    """Return the value of the resource's ``Name`` tag, if any."""
    return next((t["Value"] for t in resource.get("Tags", []) if t.get("Key") == "Name"), None)


@router.get("/meta/aws/network")
async def aws_network(
    credential_id: int,
    region: str,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Return the VPCs, subnets and security groups visible to a credential in a region."""
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    region_use = region or cred.default_region
    vpcs = await asyncio.to_thread(aws_ops.describe_vpcs, cred, region_use)
    subnets = await asyncio.to_thread(aws_ops.describe_subnets, cred, region_use)
    sgs = await asyncio.to_thread(aws_ops.describe_security_groups, cred, region_use)
    return {
        "vpcs": [
            {
                "vpc_id": vpc.get("VpcId"),
                "name": _tag_name(vpc),
                "cidr": vpc.get("CidrBlock"),
            }
            for vpc in vpcs.get("Vpcs", [])
        ],
        "subnets": [
            {
                "subnet_id": sn.get("SubnetId"),
                "vpc_id": sn.get("VpcId"),
                "az": sn.get("AvailabilityZone"),
                "cidr": sn.get("CidrBlock"),
                "name": _tag_name(sn),
            }
            for sn in subnets.get("Subnets", [])
        ],
        "security_groups": [
            {
                "group_id": sg.get("GroupId"),
                "name": sg.get("GroupName"),
                "desc": sg.get("Description"),
                "vpc_id": sg.get("VpcId"),
            }
            for sg in sgs.get("SecurityGroups", [])
        ],
    }


@router.get("/meta/aws/keypairs")
async def aws_keypairs(
    credential_id: int,
    region: str,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """List EC2 key pairs for a credential in a region."""
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    region_use = region or cred.default_region
    resp = await asyncio.to_thread(aws_ops.describe_key_pairs, cred, region_use)
    return [
        {
            "key_name": kp.get("KeyName"),
            "key_pair_id": kp.get("KeyPairId"),
            "fingerprint": kp.get("KeyFingerprint"),
        }
        for kp in resp.get("KeyPairs", [])
    ]


@router.post("/meta/aws/keypairs", status_code=status.HTTP_201_CREATED)
async def aws_create_keypair(
    credential_id: int = Body(...),
    region: str = Body(...),
    key_name: str = Body(...),
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Create an EC2 key pair and return it, including the private key material."""
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    region_use = region or cred.default_region
    resp = await asyncio.to_thread(aws_ops.create_key_pair, cred, region_use, key_name)
    return {
        "key_name": resp.get("KeyName"),
        "key_pair_id": resp.get("KeyPairId"),
        "fingerprint": resp.get("KeyFingerprint"),
        "material": resp.get("KeyMaterial"),
    }


@router.get("/meta/aws/instance-types")
async def aws_instance_types(
    credential_id: int,
    region: str,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """List EC2 instance types with vCPU/memory/network info.

    Non-admins only see the types in ``_CUSTOMER_INSTANCE_TYPES`` (filtered both
    in the AWS request and again locally).
    """
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    region_use = region or cred.default_region
    is_admin = auth_user.role_name == "ADMIN"
    filters = None if is_admin else [{"Name": "instance-type", "Values": list(_CUSTOMER_INSTANCE_TYPES)}]
    resp = await asyncio.to_thread(aws_ops.describe_instance_types, cred, region_use, filters)
    allowed_customer = frozenset(_CUSTOMER_INSTANCE_TYPES)
    items = []
    for it in resp:
        itype = it.get("InstanceType")
        if not itype:
            continue
        if not is_admin and itype not in allowed_customer:
            continue
        items.append(
            {
                "instance_type": itype,
                "vcpu": (it.get("VCpuInfo") or {}).get("DefaultVCpus"),
                "memory_mib": (it.get("MemoryInfo") or {}).get("SizeInMiB"),
                "network_performance": (it.get("NetworkInfo") or {}).get("NetworkPerformance"),
            }
        )
    return items