feat: init_db自动执行SQL迁移脚本

数据库初始化时自动执行sql目录下的迁移脚本,支持幂等执行:
- 遍历sql/*.sql文件按顺序执行
- 忽略重复字段/索引错误(MySQL 1060/1061)
- 输出执行的迁移文件列表

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
刘正航
2026-04-22 21:52:25 +08:00
parent 2dcd7ce9f6
commit 38cb9345d6

View File

@@ -2,6 +2,7 @@
from pathlib import Path from pathlib import Path
import pymysql import pymysql
from pymysql import MySQLError
from app import create_app from app import create_app
from app.extensions import db from app.extensions import db
@@ -12,6 +13,7 @@ from app.models import DetectionConfig, SpamTrainingSample, User
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json" MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json" SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"
SQL_MIGRATIONS_DIR = BASE_DIR / "sql"
def load_mysql_cfg() -> dict: def load_mysql_cfg() -> dict:
@@ -39,6 +41,46 @@ def create_database(mysql_cfg: dict) -> None:
conn.close() conn.close()
def run_sql_migrations(mysql_cfg: dict) -> list[str]:
"""执行 sql 目录下的迁移脚本"""
if not SQL_MIGRATIONS_DIR.exists():
return []
conn = pymysql.connect(
host=mysql_cfg.get("host", "127.0.0.1"),
port=int(mysql_cfg.get("port", 3306)),
user=mysql_cfg["user"],
password=mysql_cfg["password"],
database=mysql_cfg["database"],
charset=mysql_cfg.get("charset", "utf8mb4"),
autocommit=True,
)
executed = []
try:
with conn.cursor() as cursor:
sql_files = sorted(SQL_MIGRATIONS_DIR.glob("*.sql"))
for sql_file in sql_files:
sql_content = sql_file.read_text(encoding="utf-8")
statements = [s.strip() for s in sql_content.split(";") if s.strip() and not s.strip().startswith("--")]
for stmt in statements:
if stmt:
try:
cursor.execute(stmt)
except MySQLError as e:
if "1060" in str(e):
pass
elif "1061" in str(e):
pass
else:
print(f"SQL 警告 ({sql_file.name}): {e}")
executed.append(sql_file.name)
finally:
conn.close()
return executed
def ensure_seed_file() -> None: def ensure_seed_file() -> None:
if SPAM_SEED_PATH.exists(): if SPAM_SEED_PATH.exists():
return return
@@ -135,6 +177,7 @@ def main():
app = create_app() app = create_app()
with app.app_context(): with app.app_context():
db.create_all() db.create_all()
migrations = run_sql_migrations(mysql_cfg)
created, updated = seed_samples() created, updated = seed_samples()
threshold = ensure_detection_config(mysql_cfg) threshold = ensure_detection_config(mysql_cfg)
admin_msg = init_admin(mysql_cfg) admin_msg = init_admin(mysql_cfg)
@@ -146,6 +189,8 @@ def main():
print(f"- 初始阈值: {threshold}") print(f"- 初始阈值: {threshold}")
print(f"- {admin_msg}") print(f"- {admin_msg}")
print(f"- 模型版本: {model_meta.get('version')}") print(f"- 模型版本: {model_meta.get('version')}")
if migrations:
print(f"- SQL迁移: {', '.join(migrations)}")
if __name__ == "__main__": if __name__ == "__main__":