feat: init_db自动执行SQL迁移脚本
数据库初始化时自动执行sql目录下的迁移脚本,支持幂等执行: - 遍历sql/*.sql文件按顺序执行 - 忽略重复字段/索引错误(MySQL 1060/1061) - 输出执行的迁移文件列表 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pymysql
|
||||
from pymysql import MySQLError
|
||||
|
||||
from app import create_app
|
||||
from app.extensions import db
|
||||
@@ -12,6 +13,7 @@ from app.models import DetectionConfig, SpamTrainingSample, User
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
|
||||
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"
|
||||
SQL_MIGRATIONS_DIR = BASE_DIR / "sql"
|
||||
|
||||
|
||||
def load_mysql_cfg() -> dict:
|
||||
@@ -39,6 +41,46 @@ def create_database(mysql_cfg: dict) -> None:
|
||||
conn.close()
|
||||
|
||||
|
||||
def run_sql_migrations(mysql_cfg: dict) -> list[str]:
|
||||
"""执行 sql 目录下的迁移脚本"""
|
||||
if not SQL_MIGRATIONS_DIR.exists():
|
||||
return []
|
||||
|
||||
conn = pymysql.connect(
|
||||
host=mysql_cfg.get("host", "127.0.0.1"),
|
||||
port=int(mysql_cfg.get("port", 3306)),
|
||||
user=mysql_cfg["user"],
|
||||
password=mysql_cfg["password"],
|
||||
database=mysql_cfg["database"],
|
||||
charset=mysql_cfg.get("charset", "utf8mb4"),
|
||||
autocommit=True,
|
||||
)
|
||||
|
||||
executed = []
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
sql_files = sorted(SQL_MIGRATIONS_DIR.glob("*.sql"))
|
||||
for sql_file in sql_files:
|
||||
sql_content = sql_file.read_text(encoding="utf-8")
|
||||
statements = [s.strip() for s in sql_content.split(";") if s.strip() and not s.strip().startswith("--")]
|
||||
for stmt in statements:
|
||||
if stmt:
|
||||
try:
|
||||
cursor.execute(stmt)
|
||||
except MySQLError as e:
|
||||
if "1060" in str(e):
|
||||
pass
|
||||
elif "1061" in str(e):
|
||||
pass
|
||||
else:
|
||||
print(f"SQL 警告 ({sql_file.name}): {e}")
|
||||
executed.append(sql_file.name)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return executed
|
||||
|
||||
|
||||
def ensure_seed_file() -> None:
|
||||
if SPAM_SEED_PATH.exists():
|
||||
return
|
||||
@@ -135,6 +177,7 @@ def main():
|
||||
app = create_app()
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
migrations = run_sql_migrations(mysql_cfg)
|
||||
created, updated = seed_samples()
|
||||
threshold = ensure_detection_config(mysql_cfg)
|
||||
admin_msg = init_admin(mysql_cfg)
|
||||
@@ -146,6 +189,8 @@ def main():
|
||||
print(f"- 初始阈值: {threshold}")
|
||||
print(f"- {admin_msg}")
|
||||
print(f"- 模型版本: {model_meta.get('version')}")
|
||||
if migrations:
|
||||
print(f"- SQL迁移: {', '.join(migrations)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user