feat: init_db自动执行SQL迁移脚本
数据库初始化时自动执行sql目录下的迁移脚本,支持幂等执行: - 遍历sql/*.sql文件按顺序执行 - 忽略重复字段/索引错误(MySQL 1060/1061) - 输出执行的迁移文件列表 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pymysql
|
import pymysql
|
||||||
|
from pymysql import MySQLError
|
||||||
|
|
||||||
from app import create_app
|
from app import create_app
|
||||||
from app.extensions import db
|
from app.extensions import db
|
||||||
@@ -12,6 +13,7 @@ from app.models import DetectionConfig, SpamTrainingSample, User
|
|||||||
BASE_DIR = Path(__file__).resolve().parent
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
|
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
|
||||||
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"
|
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"
|
||||||
|
SQL_MIGRATIONS_DIR = BASE_DIR / "sql"
|
||||||
|
|
||||||
|
|
||||||
def load_mysql_cfg() -> dict:
|
def load_mysql_cfg() -> dict:
|
||||||
@@ -39,6 +41,46 @@ def create_database(mysql_cfg: dict) -> None:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def run_sql_migrations(mysql_cfg: dict) -> list[str]:
|
||||||
|
"""执行 sql 目录下的迁移脚本"""
|
||||||
|
if not SQL_MIGRATIONS_DIR.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
conn = pymysql.connect(
|
||||||
|
host=mysql_cfg.get("host", "127.0.0.1"),
|
||||||
|
port=int(mysql_cfg.get("port", 3306)),
|
||||||
|
user=mysql_cfg["user"],
|
||||||
|
password=mysql_cfg["password"],
|
||||||
|
database=mysql_cfg["database"],
|
||||||
|
charset=mysql_cfg.get("charset", "utf8mb4"),
|
||||||
|
autocommit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
executed = []
|
||||||
|
try:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
sql_files = sorted(SQL_MIGRATIONS_DIR.glob("*.sql"))
|
||||||
|
for sql_file in sql_files:
|
||||||
|
sql_content = sql_file.read_text(encoding="utf-8")
|
||||||
|
statements = [s.strip() for s in sql_content.split(";") if s.strip() and not s.strip().startswith("--")]
|
||||||
|
for stmt in statements:
|
||||||
|
if stmt:
|
||||||
|
try:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
except MySQLError as e:
|
||||||
|
if "1060" in str(e):
|
||||||
|
pass
|
||||||
|
elif "1061" in str(e):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print(f"SQL 警告 ({sql_file.name}): {e}")
|
||||||
|
executed.append(sql_file.name)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return executed
|
||||||
|
|
||||||
|
|
||||||
def ensure_seed_file() -> None:
|
def ensure_seed_file() -> None:
|
||||||
if SPAM_SEED_PATH.exists():
|
if SPAM_SEED_PATH.exists():
|
||||||
return
|
return
|
||||||
@@ -135,6 +177,7 @@ def main():
|
|||||||
app = create_app()
|
app = create_app()
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
db.create_all()
|
db.create_all()
|
||||||
|
migrations = run_sql_migrations(mysql_cfg)
|
||||||
created, updated = seed_samples()
|
created, updated = seed_samples()
|
||||||
threshold = ensure_detection_config(mysql_cfg)
|
threshold = ensure_detection_config(mysql_cfg)
|
||||||
admin_msg = init_admin(mysql_cfg)
|
admin_msg = init_admin(mysql_cfg)
|
||||||
@@ -146,6 +189,8 @@ def main():
|
|||||||
print(f"- 初始阈值: {threshold}")
|
print(f"- 初始阈值: {threshold}")
|
||||||
print(f"- {admin_msg}")
|
print(f"- {admin_msg}")
|
||||||
print(f"- 模型版本: {model_meta.get('version')}")
|
print(f"- 模型版本: {model_meta.get('version')}")
|
||||||
|
if migrations:
|
||||||
|
print(f"- SQL迁移: {', '.join(migrations)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user