diff --git a/backend/init_db.py b/backend/init_db.py index e17c0f9..9e355eb 100644 --- a/backend/init_db.py +++ b/backend/init_db.py @@ -2,6 +2,7 @@ from pathlib import Path import pymysql +from pymysql import MySQLError from app import create_app from app.extensions import db @@ -12,6 +13,7 @@ from app.models import DetectionConfig, SpamTrainingSample, User BASE_DIR = Path(__file__).resolve().parent MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json" SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json" +SQL_MIGRATIONS_DIR = BASE_DIR / "sql" def load_mysql_cfg() -> dict: @@ -39,6 +41,46 @@ def create_database(mysql_cfg: dict) -> None: conn.close() +def run_sql_migrations(mysql_cfg: dict) -> list[str]: + """执行 sql 目录下的迁移脚本""" + if not SQL_MIGRATIONS_DIR.exists(): + return [] + + conn = pymysql.connect( + host=mysql_cfg.get("host", "127.0.0.1"), + port=int(mysql_cfg.get("port", 3306)), + user=mysql_cfg["user"], + password=mysql_cfg["password"], + database=mysql_cfg["database"], + charset=mysql_cfg.get("charset", "utf8mb4"), + autocommit=True, + ) + + executed = [] + try: + with conn.cursor() as cursor: + sql_files = sorted(SQL_MIGRATIONS_DIR.glob("*.sql")) + for sql_file in sql_files: + sql_content = sql_file.read_text(encoding="utf-8") + statements = [s.strip() for s in sql_content.split(";") if s.strip() and not s.strip().startswith("--")] + for stmt in statements: + if stmt: + try: + cursor.execute(stmt) + except MySQLError as e: + if "1060" in str(e): + pass + elif "1061" in str(e): + pass + else: + print(f"SQL 警告 ({sql_file.name}): {e}") + executed.append(sql_file.name) + finally: + conn.close() + + return executed + + def ensure_seed_file() -> None: if SPAM_SEED_PATH.exists(): return @@ -135,6 +177,7 @@ def main(): app = create_app() with app.app_context(): db.create_all() + migrations = run_sql_migrations(mysql_cfg) created, updated = seed_samples() threshold = ensure_detection_config(mysql_cfg) admin_msg = init_admin(mysql_cfg) @@ -146,6 +189,8 @@ def main(): print(f"- 初始阈值: {threshold}") print(f"- {admin_msg}") print(f"- 模型版本: {model_meta.get('version')}") + if migrations: + print(f"- SQL迁移: {', '.join(migrations)}") if __name__ == "__main__":