aws-mt5/db.py

297 lines
11 KiB
Python
Raw Normal View History

2026-01-04 18:58:20 +08:00
import os
from contextlib import contextmanager
2026-01-05 11:07:55 +08:00
from datetime import datetime, timedelta, timezone
from typing import Iterable, Optional, List, Dict
2026-01-04 18:58:20 +08:00
2026-01-05 11:07:55 +08:00
from sqlalchemy import Column, DateTime, Integer, String, Float, create_engine, select
2026-01-04 18:58:20 +08:00
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import declarative_base, sessionmaker
2026-01-05 11:07:55 +08:00
# Connection URL, overridable via the DATABASE_URL environment variable.
# NOTE(review): the fallback looks like placeholder credentials; deployments
# presumably always set DATABASE_URL.
DATABASE_URL = os.environ.get("DATABASE_URL", "mysql+pymysql://username:password@localhost:3306/ip_ops")
2026-01-04 18:58:20 +08:00
# Shared engine; pool_pre_ping makes SQLAlchemy check a pooled connection is
# still alive before handing it out.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
# Session factory used by db_session(); flushes happen only on commit/explicit flush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base for all ORM models in this module.
Base = declarative_base()
class IPOperation(Base):
    """ORM model for the ``ip_operations`` table.

    Every ``ip_address`` stored here is returned by ``load_disallowed_ips``.
    """
    __tablename__ = "ip_operations"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique, indexed IP string (max 64 chars).
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    # Optional free-text note; not read by the code visible in this module.
    note = Column(String(255), nullable=True)
2026-01-05 11:07:55 +08:00
class IPAccountMapping(Base):
    """ORM model for ``ip_account_mapping``: maps each IP to an account name.

    Read by ``get_account_by_ip`` and rewritten by ``update_ip_account_mapping``.
    """
    __tablename__ = "ip_account_mapping"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique: an IP can be mapped to at most one account.
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
class ServerSpec(Base):
    """ORM model for ``server_specs``: one row of instance details per IP.

    Written by ``upsert_server_spec``; read by ``get_server_spec`` and
    ``get_history_chains``.
    """
    __tablename__ = "server_specs"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique key used for upserts and lookups.
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    instance_type = Column(String(64), nullable=True)
    instance_name = Column(String(255), nullable=True)
    volume_type = Column(String(64), nullable=True)
    # Lists stored comma-joined (",".join on write, split(",") on read).
    security_group_names = Column(String(512), nullable=True)
    security_group_ids = Column(String(512), nullable=True)
    region = Column(String(64), nullable=True)
    subnet_id = Column(String(128), nullable=True)
    availability_zone = Column(String(64), nullable=True)
    # Supplied by upsert_server_spec (defaulting to the current UTC+08:00 time);
    # get_history_chains reports it as a chain's first_ip_start.
    created_at = Column(DateTime(timezone=True), nullable=False)
class IPReplacementHistory(Base):
    """ORM model for ``ip_replacement_history``: one row per old_ip -> new_ip replacement."""
    __tablename__ = "ip_replacement_history"
    id = Column(Integer, primary_key=True, autoincrement=True)
    old_ip = Column(String(64), nullable=False, index=True)
    new_ip = Column(String(64), nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    # Chain identifier shared by successive replacements (see resolve_group_id);
    # rows with no group_id are grouped by their old_ip when chains are read.
    group_id = Column(String(128), nullable=True, index=True)
    # NOTE(review): presumably outbound network traffic (MB) of the terminated
    # instance - confirm with the caller that supplies it.
    terminated_network_out_mb = Column(Float, nullable=True)
    # Insert time, stamped with the UTC+08:00 clock by add_replacement_history.
    created_at = Column(DateTime(timezone=True), nullable=False)
def resolve_group_id(old_ip: str) -> str:
    """Return the group id a replacement of ``old_ip`` should belong to.

    If an earlier history row replaced some IP *with* ``old_ip`` (its
    ``new_ip == old_ip``), the group id of the most recent such row (highest
    id) is inherited, so the new replacement extends the same chain.
    Otherwise, or if the inherited group id is empty/NULL, ``old_ip`` itself
    becomes the identifier of a new group.
    """
    with db_session() as session:
        prev = session.scalar(
            select(IPReplacementHistory.group_id)
            .where(IPReplacementHistory.new_ip == old_ip)
            .order_by(IPReplacementHistory.id.desc())
            # Only the newest match is used; without LIMIT the database may
            # send back every matching row just for the first one to be read.
            .limit(1)
        )
    return prev or old_ip
2026-01-04 18:58:20 +08:00
def init_db() -> None:
    """Create every table declared on ``Base`` that does not exist yet."""
    metadata = Base.metadata
    metadata.create_all(engine, checkfirst=True)
@contextmanager
def db_session():
    """Provide a transactional Session scope.

    Yields a new session from ``SessionLocal``. It commits when the ``with``
    body completes normally and rolls back when anything raises, re-raising
    the exception unchanged. The session is always closed afterwards.
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except BaseException:
        # Roll back on *any* failure, not only SQLAlchemyError. Otherwise e.g.
        # a ValueError raised inside the ``with`` body would skip the explicit
        # rollback and rely on close() to discard pending changes implicitly.
        session.rollback()
        raise
    finally:
        session.close()
def load_disallowed_ips() -> set[str]:
    """Return the set of every ``ip_address`` stored in ``ip_operations``."""
    with db_session() as session:
        # Selecting a single column makes scalars() yield plain strings,
        # not IPOperation instances.
        ip_addresses: Iterable[str] = session.scalars(select(IPOperation.ip_address))
        return set(ip_addresses)
2026-01-05 11:07:55 +08:00
def get_account_by_ip(ip: str) -> Optional[str]:
    """Return the account name mapped to ``ip``, or None when there is no mapping."""
    lookup = select(IPAccountMapping.account_name).where(IPAccountMapping.ip_address == ip)
    with db_session() as session:
        account_name = session.scalar(lookup)
    return account_name
def update_ip_account_mapping(old_ip: str, new_ip: str, account_name: str) -> None:
    """Move the account mapping of ``old_ip`` over to ``new_ip``.

    If a row for ``old_ip`` exists, its ip_address and account_name are
    rewritten in place. Otherwise a new ``(new_ip, account_name)`` row is
    inserted.

    Raises:
        ValueError: ``new_ip`` is already mapped by a row other than the
            ``old_ip`` row.
    """
    def _mapping_for(session, address):
        # Fetch the mapping row for one IP (None when absent).
        return session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == address)
        )

    with db_session() as session:
        source = _mapping_for(session, old_ip)
        conflict_mapping = _mapping_for(session, new_ip)
        taken_elsewhere = conflict_mapping is not None and (
            source is None or conflict_mapping.id != source.id
        )
        if taken_elsewhere:
            raise ValueError(f"IP {new_ip} 已经映射到账户 {conflict_mapping.account_name}")
        if source is None:
            session.add(IPAccountMapping(ip_address=new_ip, account_name=account_name))
        else:
            source.ip_address = new_ip
            source.account_name = account_name
def _now_cn() -> datetime:
return datetime.now(timezone(timedelta(hours=8)))
def upsert_server_spec(
    *,
    ip_address: str,
    account_name: str,
    instance_type: Optional[str],
    instance_name: Optional[str],
    volume_type: Optional[str],
    security_group_names: List[str],
    security_group_ids: List[str],
    region: Optional[str],
    subnet_id: Optional[str],
    availability_zone: Optional[str],
    created_at: Optional[datetime] = None,
) -> None:
    """Insert or update the ``server_specs`` row keyed by ``ip_address``.

    All descriptive fields are overwritten with the given values. Security
    group lists are stored comma-joined.

    ``created_at`` is the server's creation time. On insert it defaults to
    the current UTC+08:00 time. On update it is replaced only when the caller
    passes one explicitly, so refreshing the other fields keeps the original
    creation time (which get_history_chains reports as ``first_ip_start``).
    """
    payload = {
        "account_name": account_name,
        "instance_type": instance_type,
        "instance_name": instance_name,
        "volume_type": volume_type,
        "security_group_names": ",".join(security_group_names),
        "security_group_ids": ",".join(security_group_ids),
        "region": region,
        "subnet_id": subnet_id,
        "availability_zone": availability_zone,
    }
    with db_session() as session:
        spec = session.scalar(select(ServerSpec).where(ServerSpec.ip_address == ip_address))
        if spec is None:
            session.add(
                ServerSpec(ip_address=ip_address, created_at=created_at or _now_cn(), **payload)
            )
            return
        for key, val in payload.items():
            setattr(spec, key, val)
        # Keep the stored creation time unless a new one was supplied
        # (or, defensively, the stored value is missing).
        if created_at is not None or spec.created_at is None:
            spec.created_at = created_at or _now_cn()
def get_server_spec(ip_address: str) -> Optional[Dict[str, object]]:
    """Return the stored spec for ``ip_address`` as a plain dict, or None.

    Values are strings or None, except:
    - ``security_group_names``/``security_group_ids`` are lists, split back
      from the comma-joined columns (empty/NULL -> ``[]``);
    - ``created_at`` is the datetime loaded from the database.
    """
    def _split(joined: Optional[str]) -> List[str]:
        # Inverse of the ",".join(...) used when the spec is written.
        return joined.split(",") if joined else []

    with db_session() as session:
        spec = session.scalar(select(ServerSpec).where(ServerSpec.ip_address == ip_address))
        if spec is None:
            return None
        return {
            "ip_address": spec.ip_address,
            "account_name": spec.account_name,
            "instance_type": spec.instance_type,
            "instance_name": spec.instance_name,
            "volume_type": spec.volume_type,
            "security_group_names": _split(spec.security_group_names),
            "security_group_ids": _split(spec.security_group_ids),
            "region": spec.region,
            "subnet_id": spec.subnet_id,
            "availability_zone": spec.availability_zone,
            "created_at": spec.created_at,
        }
def add_replacement_history(
    old_ip: str,
    new_ip: str,
    account_name: str,
    group_id: Optional[str],
    terminated_network_out_mb: Optional[float] = None,
) -> None:
    """Append one old_ip -> new_ip row to ``ip_replacement_history``.

    A falsy ``group_id`` is replaced by ``resolve_group_id(old_ip)``, which
    runs in its own session before this insert. ``created_at`` is stamped
    with the current UTC+08:00 time.
    """
    if group_id:
        effective_group = group_id
    else:
        effective_group = resolve_group_id(old_ip)

    with db_session() as session:
        entry = IPReplacementHistory(
            old_ip=old_ip,
            new_ip=new_ip,
            account_name=account_name,
            group_id=effective_group,
            terminated_network_out_mb=terminated_network_out_mb,
            created_at=_now_cn(),
        )
        session.add(entry)
def get_replacement_history(limit: int = 50) -> List[Dict[str, object]]:
    """Return up to ``limit`` replacement rows, newest (highest id) first.

    Each dict carries old_ip, new_ip, account_name, group_id (may be None),
    terminated_network_out_mb (float or None) and created_at as an ISO-8601
    string. The values are therefore not all ``str``.
    """
    stmt = (
        select(IPReplacementHistory)
        .order_by(IPReplacementHistory.id.desc())
        .limit(limit)
    )
    with db_session() as session:
        return [
            {
                "old_ip": row.old_ip,
                "new_ip": row.new_ip,
                "account_name": row.account_name,
                "group_id": row.group_id,
                "terminated_network_out_mb": row.terminated_network_out_mb,
                "created_at": row.created_at.isoformat(),
            }
            for row in session.scalars(stmt)
        ]
def get_history_by_ip_or_group(ip: Optional[str], group_id: Optional[str], limit: int = 200) -> List[Dict[str, object]]:
    """Return up to ``limit`` matching replacement rows, newest (highest id) first.

    ``group_id`` takes precedence. Otherwise rows whose old_ip or new_ip equal
    ``ip`` are returned. With neither filter, all rows are considered. Dict
    values are str, None (group_id) or float (terminated_network_out_mb);
    created_at is an ISO-8601 string.
    """
    stmt = select(IPReplacementHistory)
    if group_id:
        stmt = stmt.where(IPReplacementHistory.group_id == group_id)
    elif ip:
        stmt = stmt.where(
            (IPReplacementHistory.old_ip == ip) | (IPReplacementHistory.new_ip == ip)
        )
    stmt = stmt.order_by(IPReplacementHistory.id.desc()).limit(limit)
    with db_session() as session:
        return [
            {
                "old_ip": row.old_ip,
                "new_ip": row.new_ip,
                "account_name": row.account_name,
                "group_id": row.group_id,
                "terminated_network_out_mb": row.terminated_network_out_mb,
                "created_at": row.created_at.isoformat(),
            }
            for row in session.scalars(stmt)
        ]
def _build_chain(items: List[Dict[str, object]]) -> List[str]:
    """Flatten chronologically ordered replacement items into an ordered IP chain.

    Starts with the first item's old_ip. For each item, its old_ip is
    appended only if it differs from the last entry and is not already in
    the chain; its new_ip is appended if it differs from the last entry.
    """
    chain: List[str] = []
    for item in items:
        old_ip, new_ip = item["old_ip"], item["new_ip"]
        if not chain:
            chain.append(old_ip)
        if chain[-1] != old_ip and old_ip not in chain:
            chain.append(old_ip)
        if chain[-1] != new_ip:
            chain.append(new_ip)
    return chain


def get_history_chains(ip: Optional[str] = None, group_id: Optional[str] = None, limit: int = 500) -> List[Dict[str, object]]:
    """Return replacement history aggregated by group, with each group's IP chain.

    ``group_id`` takes precedence over ``ip``, which is matched against old_ip
    or new_ip. With neither filter, all rows are considered. At most ``limit``
    rows are read in ascending ``created_at`` order, with ``id`` as the
    tie-breaker.

    Each returned dict has:
      - ``group_id``: the rows' group_id (or old_ip when group_id is empty);
      - ``items``: the group's rows (old_ip, new_ip, account_name,
        terminated_network_out_mb, created_at as ISO string), oldest first;
      - ``chain``: ordered list of IPs derived from ``items``;
      - ``first_ip_start``: ISO ``server_specs.created_at`` of the chain's
        first IP, or None when that IP has no spec row.
    Groups are ordered by their earliest item.
    """
    with db_session() as session:
        stmt = select(IPReplacementHistory)
        if group_id:
            stmt = stmt.where(IPReplacementHistory.group_id == group_id)
        elif ip:
            stmt = stmt.where(
                (IPReplacementHistory.old_ip == ip) | (IPReplacementHistory.new_ip == ip)
            )
        # id breaks ties between rows sharing a created_at value (likely if the
        # column is stored at whole-second precision). Without it, a later hop
        # could be read before an earlier one and the chain built out of order.
        # NOTE(review): ascending order + LIMIT returns the *oldest* `limit`
        # rows when more rows match - confirm that is intended.
        stmt = stmt.order_by(
            IPReplacementHistory.created_at.asc(), IPReplacementHistory.id.asc()
        ).limit(limit)

        groups: Dict[str, Dict[str, object]] = {}
        for row in session.scalars(stmt):
            gid = row.group_id or row.old_ip
            if gid not in groups:
                groups[gid] = {"group_id": gid, "items": [], "chain": [], "first_ip_start": None}
            groups[gid]["items"].append(
                {
                    "old_ip": row.old_ip,
                    "new_ip": row.new_ip,
                    "account_name": row.account_name,
                    "terminated_network_out_mb": row.terminated_network_out_mb,
                    "created_at": row.created_at.isoformat(),
                }
            )

        # Rows arrive sorted by (created_at, id), so each group's items are
        # already in chronological order.
        for data in groups.values():
            data["chain"] = _build_chain(data["items"])

        # Fetch creation times of all chain heads in one query instead of
        # issuing one query per group.
        first_ips = sorted({data["chain"][0] for data in groups.values() if data["chain"]})
        start_times: Dict[str, datetime] = {}
        if first_ips:
            start_times = dict(
                session.execute(
                    select(ServerSpec.ip_address, ServerSpec.created_at).where(
                        ServerSpec.ip_address.in_(first_ips)
                    )
                ).all()
            )
        for data in groups.values():
            chain = data["chain"]
            spec_time = start_times.get(chain[0]) if chain else None
            if spec_time is not None:
                data["first_ip_start"] = spec_time.isoformat()

    return sorted(
        groups.values(),
        key=lambda g: g["items"][0]["created_at"] if g["items"] else "",
    )