"""Persistence layer for IP operations: ORM models and query helpers.

The module owns the SQLAlchemy engine/session factory and exposes small
functions that each run inside their own short-lived session.
"""

import os
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Iterable, Optional, List, Dict

from sqlalchemy import Column, DateTime, Integer, String, Float, create_engine, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import declarative_base, sessionmaker

# Connection string is taken from the environment; the fallback is a
# placeholder DSN for a local MySQL instance.
DATABASE_URL = os.getenv("DATABASE_URL", "mysql+pymysql://username:password@localhost:3306/ip_ops")
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()


class IPOperation(Base):
    """Row of ``ip_operations``; read by ``load_disallowed_ips``."""

    __tablename__ = "ip_operations"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    note = Column(String(255), nullable=True)


class IPAccountMapping(Base):
    """Row of ``ip_account_mapping``: one account name per unique IP."""

    __tablename__ = "ip_account_mapping"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)


class ServerSpec(Base):
    """Row of ``server_specs``: instance attributes keyed by unique IP."""

    __tablename__ = "server_specs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    instance_type = Column(String(64), nullable=True)
    instance_name = Column(String(255), nullable=True)
    volume_type = Column(String(64), nullable=True)
    # Stored as comma-joined strings; see upsert_server_spec / get_server_spec.
    security_group_names = Column(String(512), nullable=True)
    security_group_ids = Column(String(512), nullable=True)
    region = Column(String(64), nullable=True)
    subnet_id = Column(String(128), nullable=True)
    availability_zone = Column(String(64), nullable=True)
    created_at = Column(DateTime(timezone=True), nullable=False)


class IPReplacementHistory(Base):
    """Row of ``ip_replacement_history``: one old_ip -> new_ip replacement."""

    __tablename__ = "ip_replacement_history"

    id = Column(Integer, primary_key=True, autoincrement=True)
    old_ip = Column(String(64), nullable=False, index=True)
    new_ip = Column(String(64), nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    group_id = Column(String(128), nullable=True, index=True)
    terminated_network_out_mb = Column(Float, nullable=True)
    created_at = Column(DateTime(timezone=True), nullable=False)


def resolve_group_id(old_ip: str) -> str:
    """Return the group id a replacement of ``old_ip`` belongs to.

    The group id is inherited from the most recent history row whose
    ``new_ip`` equals ``old_ip``; when there is no such row (or its group id
    is empty) ``old_ip`` itself becomes the new group identifier.
    """
    with db_session() as session:
        prev = session.scalar(
            select(IPReplacementHistory.group_id)
            .where(IPReplacementHistory.new_ip == old_ip)
            .order_by(IPReplacementHistory.id.desc())
            # Only the newest row is read, so do not fetch the rest.
            .limit(1)
        )
        return prev or old_ip


def init_db() -> None:
    """Create every table declared on ``Base`` that does not exist yet."""
    Base.metadata.create_all(bind=engine)


@contextmanager
def db_session():
    """Yield a session that commits on success and always closes.

    Any exception raised inside the ``with`` body (database errors as well
    as application errors such as ``ValueError``) triggers an explicit
    rollback and is then re-raised unchanged.
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def load_disallowed_ips() -> set[str]:
    """Return the set of all ``ip_address`` values in ``ip_operations``."""
    with db_session() as session:
        addresses: Iterable[str] = session.scalars(select(IPOperation.ip_address))
        return set(addresses)


def get_account_by_ip(ip: str) -> Optional[str]:
    """Return the account name mapped to ``ip``, or ``None`` if unmapped."""
    with db_session() as session:
        return session.scalar(
            select(IPAccountMapping.account_name).where(IPAccountMapping.ip_address == ip)
        )


def update_ip_account_mapping(old_ip: str, new_ip: str, account_name: str) -> None:
    """Move the mapping of ``old_ip`` to ``new_ip`` (or create one for ``new_ip``).

    Raises:
        ValueError: if ``new_ip`` is already mapped by a different row than
            the one currently holding ``old_ip``.
    """
    with db_session() as session:
        existing_mapping = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == old_ip)
        )
        conflict_mapping = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == new_ip)
        )
        # A row for new_ip is only acceptable when it is the very row being
        # updated (i.e. old_ip and new_ip resolve to the same mapping).
        if conflict_mapping and (
            not existing_mapping or conflict_mapping.id != existing_mapping.id
        ):
            raise ValueError(f"IP {new_ip} 已经映射到账户 {conflict_mapping.account_name}")

        if existing_mapping:
            existing_mapping.ip_address = new_ip
            existing_mapping.account_name = account_name
        else:
            session.add(IPAccountMapping(ip_address=new_ip, account_name=account_name))


def _now_cn() -> datetime:
    """Return the current time as an aware datetime at fixed offset UTC+8."""
    return datetime.now(timezone(timedelta(hours=8)))


def upsert_server_spec(
    *,
    ip_address: str,
    account_name: str,
    instance_type: Optional[str],
    instance_name: Optional[str],
    volume_type: Optional[str],
    security_group_names: List[str],
    security_group_ids: List[str],
    region: Optional[str],
    subnet_id: Optional[str],
    availability_zone: Optional[str],
    created_at: Optional[datetime] = None,
) -> None:
    """Insert or update the ``server_specs`` row for ``ip_address``.

    Security group names/ids are stored as comma-joined strings.
    ``created_at`` defaults to the current UTC+8 time when omitted.
    """
    with db_session() as session:
        spec = session.scalar(select(ServerSpec).where(ServerSpec.ip_address == ip_address))
        # NOTE(review): on update this also overwrites created_at (with "now"
        # when the caller passes none) — confirm that is intended.
        payload = {
            "account_name": account_name,
            "instance_type": instance_type,
            "instance_name": instance_name,
            "volume_type": volume_type,
            "security_group_names": ",".join(security_group_names),
            "security_group_ids": ",".join(security_group_ids),
            "region": region,
            "subnet_id": subnet_id,
            "availability_zone": availability_zone,
            "created_at": created_at or _now_cn(),
        }
        if spec:
            for key, val in payload.items():
                setattr(spec, key, val)
        else:
            session.add(ServerSpec(ip_address=ip_address, **payload))


def get_server_spec(ip_address: str) -> Optional[Dict[str, object]]:
    """Return the stored spec for ``ip_address`` as a dict, or ``None``.

    Security group fields are split back into lists (empty list when the
    column is empty); ``created_at`` is returned as stored (a datetime).
    """
    with db_session() as session:
        spec = session.scalar(select(ServerSpec).where(ServerSpec.ip_address == ip_address))
        if not spec:
            return None
        return {
            "ip_address": spec.ip_address,
            "account_name": spec.account_name,
            "instance_type": spec.instance_type,
            "instance_name": spec.instance_name,
            "volume_type": spec.volume_type,
            "security_group_names": spec.security_group_names.split(",") if spec.security_group_names else [],
            "security_group_ids": spec.security_group_ids.split(",") if spec.security_group_ids else [],
            "region": spec.region,
            "subnet_id": spec.subnet_id,
            "availability_zone": spec.availability_zone,
            "created_at": spec.created_at,
        }


def add_replacement_history(
    old_ip: str,
    new_ip: str,
    account_name: str,
    group_id: Optional[str],
    terminated_network_out_mb: Optional[float] = None,
) -> None:
    """Append one replacement record, timestamped with the current UTC+8 time.

    When ``group_id`` is empty it is derived via ``resolve_group_id(old_ip)``.
    """
    resolved_group = group_id or resolve_group_id(old_ip)
    with db_session() as session:
        session.add(
            IPReplacementHistory(
                old_ip=old_ip,
                new_ip=new_ip,
                account_name=account_name,
                group_id=resolved_group,
                terminated_network_out_mb=terminated_network_out_mb,
                created_at=_now_cn(),
            )
        )


def _history_row_to_dict(
    row: IPReplacementHistory, *, include_group_id: bool = True
) -> Dict[str, object]:
    """Serialize one history row; ``created_at`` becomes an ISO-8601 string."""
    entry: Dict[str, object] = {
        "old_ip": row.old_ip,
        "new_ip": row.new_ip,
        "account_name": row.account_name,
    }
    if include_group_id:
        entry["group_id"] = row.group_id
    entry["terminated_network_out_mb"] = row.terminated_network_out_mb
    entry["created_at"] = row.created_at.isoformat()
    return entry


def _filter_history(stmt, ip: Optional[str], group_id: Optional[str]):
    """Restrict a history select by group id, else by IP, else not at all."""
    if group_id:
        return stmt.where(IPReplacementHistory.group_id == group_id)
    if ip:
        return stmt.where(
            (IPReplacementHistory.old_ip == ip) | (IPReplacementHistory.new_ip == ip)
        )
    return stmt


def _build_chain(items: List[Dict[str, object]]) -> List[str]:
    """Fold ordered replacement entries into the sequence of IPs visited."""
    chain: List[str] = []
    for item in items:
        old_ip, new_ip = item["old_ip"], item["new_ip"]
        # An old_ip is added only the first time it is seen (this also seeds
        # the chain with the very first old_ip).
        if old_ip not in chain:
            chain.append(old_ip)
        # Avoid an immediate duplicate when new_ip equals the current tail.
        if chain[-1] != new_ip:
            chain.append(new_ip)
    return chain


def get_replacement_history(limit: int = 50) -> List[Dict[str, object]]:
    """Return up to ``limit`` history records, newest (highest id) first."""
    with db_session() as session:
        rows: Iterable[IPReplacementHistory] = session.scalars(
            select(IPReplacementHistory).order_by(IPReplacementHistory.id.desc()).limit(limit)
        )
        return [_history_row_to_dict(row) for row in rows]


def get_history_by_ip_or_group(
    ip: Optional[str], group_id: Optional[str], limit: int = 200
) -> List[Dict[str, object]]:
    """Return up to ``limit`` history records, newest first, optionally filtered.

    ``group_id`` takes precedence over ``ip``; an ``ip`` filter matches rows
    where it appears as either ``old_ip`` or ``new_ip``. With neither given,
    no filter is applied.
    """
    with db_session() as session:
        stmt = _filter_history(select(IPReplacementHistory), ip, group_id)
        stmt = stmt.order_by(IPReplacementHistory.id.desc()).limit(limit)
        rows: Iterable[IPReplacementHistory] = session.scalars(stmt)
        return [_history_row_to_dict(row) for row in rows]


def get_history_chains(
    ip: Optional[str] = None, group_id: Optional[str] = None, limit: int = 500
) -> List[Dict[str, object]]:
    """Return replacement chains aggregated by group id.

    At most ``limit`` history rows (oldest ``created_at`` first) are read and
    grouped by ``group_id`` (falling back to the row's ``old_ip``). Each group
    dict holds ``group_id``, ``items`` (entries sorted by ``created_at``),
    ``chain`` (the IP sequence) and ``first_ip_start`` (ISO string of the
    ``server_specs.created_at`` of the chain's first IP, or ``None``).
    Groups are returned ordered by their earliest entry.
    """
    with db_session() as session:
        stmt = _filter_history(
            select(IPReplacementHistory).order_by(IPReplacementHistory.created_at.asc()),
            ip,
            group_id,
        ).limit(limit)
        rows: Iterable[IPReplacementHistory] = session.scalars(stmt)

        groups: Dict[str, Dict[str, object]] = {}
        for row in rows:
            gid = row.group_id or row.old_ip
            group = groups.setdefault(
                gid,
                {"group_id": gid, "items": [], "chain": [], "first_ip_start": None},
            )
            # Chain entries intentionally carry no "group_id" key.
            group["items"].append(_history_row_to_dict(row, include_group_id=False))

        # Build the chain of each group.
        for data in groups.values():
            items = data["items"]
            items.sort(key=lambda entry: entry["created_at"])
            chain = _build_chain(items)
            data["chain"] = chain

            # Look up when the first IP of the chain was created
            # (server_specs.created_at).
            # NOTE(review): this issues one query per group.
            if chain:
                spec_time = session.scalar(
                    select(ServerSpec.created_at).where(ServerSpec.ip_address == chain[0])
                )
                if spec_time:
                    data["first_ip_start"] = spec_time.isoformat()

        # Return the groups ordered by their earliest entry.
        return sorted(
            groups.values(),
            key=lambda g: g["items"][0]["created_at"] if g["items"] else "",
        )