From 38cb9345d6000be0f6999b971f54714bbd955ef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=AD=A3=E8=88=AA?= <1915581435@qq.com> Date: Wed, 22 Apr 2026 21:52:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20init=5Fdb=E8=87=AA=E5=8A=A8=E6=89=A7?= =?UTF-8?q?=E8=A1=8CSQL=E8=BF=81=E7=A7=BB=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 数据库初始化时自动执行sql目录下的迁移脚本,支持幂等执行: - 遍历sql/*.sql文件按顺序执行 - 忽略重复字段/索引错误(MySQL 1060/1061) - 输出执行的迁移文件列表 Co-Authored-By: Claude Opus 4.6 --- backend/init_db.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/backend/init_db.py b/backend/init_db.py index e17c0f9..9e355eb 100644 --- a/backend/init_db.py +++ b/backend/init_db.py @@ -2,6 +2,7 @@ from pathlib import Path import pymysql +from pymysql import MySQLError from app import create_app from app.extensions import db @@ -12,6 +13,7 @@ from app.models import DetectionConfig, SpamTrainingSample, User BASE_DIR = Path(__file__).resolve().parent MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json" SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json" +SQL_MIGRATIONS_DIR = BASE_DIR / "sql" def load_mysql_cfg() -> dict: @@ -39,6 +41,46 @@ def create_database(mysql_cfg: dict) -> None: conn.close() +def run_sql_migrations(mysql_cfg: dict) -> list[str]: + """执行 sql 目录下的迁移脚本""" + if not SQL_MIGRATIONS_DIR.exists(): + return [] + + conn = pymysql.connect( + host=mysql_cfg.get("host", "127.0.0.1"), + port=int(mysql_cfg.get("port", 3306)), + user=mysql_cfg["user"], + password=mysql_cfg["password"], + database=mysql_cfg["database"], + charset=mysql_cfg.get("charset", "utf8mb4"), + autocommit=True, + ) + + executed = [] + try: + with conn.cursor() as cursor: + sql_files = sorted(SQL_MIGRATIONS_DIR.glob("*.sql")) + for sql_file in sql_files: + sql_content = sql_file.read_text(encoding="utf-8") + statements = [s.strip() for s in sql_content.split(";") if s.strip() and not s.strip().startswith("--")] + for stmt in statements: + if stmt: + try: + cursor.execute(stmt) + except MySQLError as e: + if "1060" in str(e): + pass + elif "1061" in str(e): + pass + else: + print(f"SQL 警告 ({sql_file.name}): {e}") + executed.append(sql_file.name) + finally: + conn.close() + + return executed + + def ensure_seed_file() -> None: if SPAM_SEED_PATH.exists(): return @@ -135,6 +177,7 @@ def main(): app = create_app() with app.app_context(): db.create_all() + migrations = run_sql_migrations(mysql_cfg) created, updated = seed_samples() threshold = ensure_detection_config(mysql_cfg) admin_msg = init_admin(mysql_cfg) @@ -146,6 +189,8 @@ def main(): print(f"- 初始阈值: {threshold}") print(f"- {admin_msg}") print(f"- 模型版本: {model_meta.get('version')}") + if migrations: + print(f"- SQL迁移: {', '.join(migrations)}") if __name__ == "__main__":