Files
c/backend/init_db.py
刘正航 38cb9345d6 feat: init_db自动执行SQL迁移脚本
数据库初始化时自动执行sql目录下的迁移脚本,支持幂等执行:
- 遍历sql/*.sql文件按顺序执行
- 忽略重复字段/索引错误(MySQL 1060/1061)
- 输出执行的迁移文件列表

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-22 21:52:25 +08:00

198 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
from pathlib import Path
import pymysql
from pymysql import MySQLError
from app import create_app
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamTrainingSample, User
# Directory that contains this script; every other path is resolved against it.
BASE_DIR = Path(__file__).resolve().parent
# JSON file holding the MySQL connection settings and the init options.
MYSQL_CONFIG_PATH = BASE_DIR.joinpath("mysqlconfig.json")
# JSON list of labelled sample texts used to seed the training table.
SPAM_SEED_PATH = BASE_DIR.joinpath("seed", "spam_samples_seed.json")
# Directory scanned for *.sql migration scripts.
SQL_MIGRATIONS_DIR = BASE_DIR.joinpath("sql")
def load_mysql_cfg() -> dict:
    """Read and decode the JSON configuration stored at ``MYSQL_CONFIG_PATH``.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
    """
    if MYSQL_CONFIG_PATH.exists():
        # "utf-8-sig" transparently drops a leading BOM when one is present.
        with MYSQL_CONFIG_PATH.open("r", encoding="utf-8-sig") as handle:
            raw_text = handle.read()
        return json.loads(raw_text)
    raise FileNotFoundError(f"未找到配置文件: {MYSQL_CONFIG_PATH}")
def create_database(mysql_cfg: dict) -> None:
    """Create the configured database if it does not exist yet.

    Connects to the MySQL server without selecting a schema and issues a
    ``CREATE DATABASE IF NOT EXISTS`` statement.

    Args:
        mysql_cfg: Connection settings. ``user``, ``password`` and
            ``database`` are required; ``host``, ``port`` and ``charset``
            default to ``127.0.0.1``, ``3306`` and ``utf8mb4``.

    Raises:
        KeyError: If a required key is missing from ``mysql_cfg``.
    """
    charset = mysql_cfg.get("charset", "utf8mb4")
    connect_kwargs = {
        "host": mysql_cfg.get("host", "127.0.0.1"),
        "port": int(mysql_cfg.get("port", 3306)),
        "user": mysql_cfg["user"],
        "password": mysql_cfg["password"],
        "charset": charset,
        "autocommit": True,
    }
    # Resolve the name before connecting so a missing key fails without
    # opening a connection. Embedded backticks are doubled so the name cannot
    # terminate the quoted identifier early.
    quoted_name = "`" + str(mysql_cfg["database"]).replace("`", "``") + "`"

    conn = pymysql.connect(**connect_kwargs)
    try:
        with conn.cursor() as cursor:
            # NOTE: the charset is interpolated as-is from the config file.
            cursor.execute(
                f"CREATE DATABASE IF NOT EXISTS {quoted_name} DEFAULT CHARACTER SET {charset}"
            )
    finally:
        conn.close()
def _split_sql_statements(sql_content: str) -> list[str]:
    """Split a SQL script into statements, ignoring full-line ``--`` comments.

    Comment lines are removed before splitting on ``;`` so that a statement
    preceded by a comment is kept rather than discarded with it.
    Semicolons inside string literals or trailing inline comments are not
    handled specially.
    """
    code_lines = [
        line
        for line in sql_content.splitlines()
        if not line.lstrip().startswith("--")
    ]
    chunks = "\n".join(code_lines).split(";")
    return [chunk.strip() for chunk in chunks if chunk.strip()]


def run_sql_migrations(mysql_cfg: dict) -> list[str]:
    """Execute every ``*.sql`` script found in ``SQL_MIGRATIONS_DIR``.

    Scripts run in sorted file-name order, statement by statement, on an
    autocommit connection. MySQL errors 1060 and 1061 are ignored silently
    so the scripts can be re-run; any other MySQL error is printed as a
    warning and execution continues with the next statement.

    Args:
        mysql_cfg: Connection settings. ``user``, ``password`` and
            ``database`` are required; ``host``, ``port`` and ``charset``
            default to ``127.0.0.1``, ``3306`` and ``utf8mb4``.

    Returns:
        Names of the processed script files, in execution order. Empty when
        the migrations directory does not exist.
    """
    if not SQL_MIGRATIONS_DIR.exists():
        return []

    # Error codes named in the commit note as duplicate field / duplicate index.
    ignorable_codes = (1060, 1061)

    conn = pymysql.connect(
        host=mysql_cfg.get("host", "127.0.0.1"),
        port=int(mysql_cfg.get("port", 3306)),
        user=mysql_cfg["user"],
        password=mysql_cfg["password"],
        database=mysql_cfg["database"],
        charset=mysql_cfg.get("charset", "utf8mb4"),
        autocommit=True,
    )
    executed = []
    try:
        with conn.cursor() as cursor:
            for sql_file in sorted(SQL_MIGRATIONS_DIR.glob("*.sql")):
                # "utf-8-sig" tolerates a BOM, which would otherwise end up
                # glued to the first statement.
                sql_content = sql_file.read_text(encoding="utf-8-sig")
                for stmt in _split_sql_statements(sql_content):
                    try:
                        cursor.execute(stmt)
                    except MySQLError as e:
                        # Compare the numeric code instead of searching the
                        # message text, which may contain the same digits.
                        error_code = e.args[0] if e.args else None
                        if error_code not in ignorable_codes:
                            print(f"SQL 警告 ({sql_file.name}): {e}")
                executed.append(sql_file.name)
    finally:
        conn.close()
    return executed
def ensure_seed_file() -> None:
    """Write a default seed file to ``SPAM_SEED_PATH`` unless one already exists.

    The parent directory is created on demand. An existing file is never
    touched.
    """
    if not SPAM_SEED_PATH.exists():
        SPAM_SEED_PATH.parent.mkdir(parents=True, exist_ok=True)
        labelled_texts = (
            ("点击链接领取1000元现金红包先到先得", "spam"),
            ("今晚项目例会改到19点记得带周报", "ham"),
            ("恭喜你成为幸运用户,马上回复领取大奖", "spam"),
            ("明天上午10点客户演示请提前准备材料", "ham"),
        )
        payload = [{"text": text, "label": label} for text, label in labelled_texts]
        # ensure_ascii=False keeps the Chinese text readable in the file.
        serialized = json.dumps(payload, ensure_ascii=False, indent=2)
        with SPAM_SEED_PATH.open("w", encoding="utf-8") as handle:
            handle.write(serialized)
def seed_samples() -> tuple[int, int]:
    """Load the seed file into the ``SpamTrainingSample`` table.

    Each seed entry's text is whitespace-normalised and its label
    lower-cased. Entries whose text is shorter than two characters or whose
    label is not ``spam``/``ham`` are skipped. A row with the same text is
    relabelled and re-activated; otherwise a new row is added. Everything is
    committed once at the end.

    Returns:
        ``(created, updated)`` row counts.
    """
    ensure_seed_file()
    with SPAM_SEED_PATH.open("r", encoding="utf-8-sig") as handle:
        seed_rows = json.load(handle)

    allowed_labels = {"spam", "ham"}
    created = updated = 0
    for entry in seed_rows:
        # split() without arguments also discards leading/trailing whitespace.
        clean_text = " ".join((entry.get("text") or "").split())
        clean_label = (entry.get("label") or "").strip().lower()
        if len(clean_text) >= 2 and clean_label in allowed_labels:
            existing = SpamTrainingSample.query.filter_by(text=clean_text).first()
            if existing:
                existing.label = clean_label
                existing.is_active = True
                # Keep an already recorded source; fill it in only when empty.
                existing.source = existing.source or "seed"
                updated += 1
            else:
                db.session.add(
                    SpamTrainingSample(
                        text=clean_text,
                        label=clean_label,
                        source="seed",
                        is_active=True,
                    )
                )
                created += 1

    db.session.commit()
    return created, updated
def ensure_detection_config(mysql_cfg: dict) -> float:
    """Return the stored spam threshold, creating the config row if needed.

    When a ``DetectionConfig`` row exists, the threshold of the one with the
    lowest id is returned. Otherwise the initial value is read from
    ``mysql_cfg["detection_init"]["spam_threshold"]`` (0.75 when missing or
    unusable), clamped to ``[0.01, 0.99]``, stored and returned.

    Args:
        mysql_cfg: Parsed configuration; a non-dict value is treated as empty.

    Returns:
        The threshold now in effect.
    """
    fallback = 0.75

    stored = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if stored:
        return float(stored.spam_threshold)

    init_section = {}
    if isinstance(mysql_cfg, dict):
        init_section = mysql_cfg.get("detection_init", {})

    try:
        threshold = float(init_section.get("spam_threshold", fallback))
    except Exception:
        # Any malformed section or value falls back to the default.
        threshold = fallback

    threshold = min(max(threshold, 0.01), 0.99)
    db.session.add(DetectionConfig(spam_threshold=threshold))
    db.session.commit()
    return threshold
def init_admin(mysql_cfg: dict) -> str:
    """Create the default administrator account when it is missing.

    Settings come from ``mysql_cfg["admin_init"]``. Creation is skipped when
    ``create_default_admin`` is falsy or a user with the configured username
    already exists.

    Args:
        mysql_cfg: Parsed configuration.

    Returns:
        A status message describing what happened.
    """
    settings = mysql_cfg.get("admin_init", {})
    if not settings.get("create_default_admin", True):
        return "默认管理员创建已关闭"

    username = settings.get("username", "admin")
    # NOTE(review): the fallback password is a fixed literal; deployments
    # should override it in the config file.
    password = settings.get("password", "Admin@123456")
    nickname = settings.get("nickname", "系统管理员")

    if User.query.filter_by(username=username).first():
        return f"管理员已存在: {username}"

    new_admin = User(username=username, nickname=nickname, is_admin=True)
    new_admin.set_password(password)
    db.session.add(new_admin)
    db.session.commit()
    return f"管理员已创建: {username}"
def train_initial_model(model_path: str) -> dict:
    """Train the naive Bayes classifier on all active training samples.

    Args:
        model_path: Path passed to ``NaiveBayesSpamClassifier``.

    Returns:
        Whatever ``NaiveBayesSpamClassifier.train`` returns (annotated as a
        dict; the caller reads its ``version`` key).
    """
    active_rows = (
        SpamTrainingSample.query
        .filter_by(is_active=True)
        .order_by(SpamTrainingSample.id.asc())
        .all()
    )
    training_set = []
    for row in active_rows:
        training_set.append({"text": row.text, "label": row.label})
    classifier = NaiveBayesSpamClassifier(model_path)
    return classifier.train(training_set)
def main():
    """Run the full initialisation sequence and print a summary.

    Steps, in order: load the config, create the database, create the ORM
    tables, run the SQL migration scripts, seed the training samples, ensure
    the detection config, create the default admin and train the initial
    model.
    """
    cfg = load_mysql_cfg()
    create_database(cfg)

    flask_app = create_app()
    with flask_app.app_context():
        db.create_all()
        applied_migrations = run_sql_migrations(cfg)
        new_count, changed_count = seed_samples()
        spam_threshold = ensure_detection_config(cfg)
        admin_status = init_admin(cfg)
        training_meta = train_initial_model(flask_app.config["NB_MODEL_PATH"])

    print("数据库初始化完成")
    print(f"- 样本新增: {new_count}")
    print(f"- 样本更新: {changed_count}")
    print(f"- 初始阈值: {spam_threshold}")
    print(f"- {admin_status}")
    print(f"- 模型版本: {training_meta.get('version')}")
    # The migration line is only shown when at least one script was processed.
    if applied_migrations:
        print(f"- SQL迁移: {', '.join(applied_migrations)}")


# Run the initialisation only when executed as a script, not on import.
if __name__ == "__main__":
    main()