"""One-shot bootstrap for the spam-detection service database.

Running this module creates the MySQL database if needed, creates the ORM
tables, seeds training samples, ensures a detection-threshold row and a
default administrator exist, and trains the initial Naive Bayes model.
"""

import json
import math
import re
from pathlib import Path
from typing import Optional

import pymysql

from app import create_app
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamTrainingSample, User

BASE_DIR = Path(__file__).resolve().parent
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"

# Fallback and clamp bounds for the initial spam threshold.
_DEFAULT_SPAM_THRESHOLD = 0.75
_MIN_SPAM_THRESHOLD = 0.01
_MAX_SPAM_THRESHOLD = 0.99

# Charset names are interpolated into DDL, so only plain identifiers are allowed.
_CHARSET_RE = re.compile(r"[A-Za-z0-9_]+")
_VALID_LABELS = frozenset({"spam", "ham"})


def load_mysql_cfg() -> dict:
    """Load and return the parsed JSON from ``MYSQL_CONFIG_PATH``.

    The file is read as ``utf-8-sig`` so a leading BOM is tolerated.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
    """
    if not MYSQL_CONFIG_PATH.exists():
        raise FileNotFoundError(f"未找到配置文件: {MYSQL_CONFIG_PATH}")
    with MYSQL_CONFIG_PATH.open("r", encoding="utf-8-sig") as file:
        return json.load(file)


def create_database(mysql_cfg: dict) -> None:
    """Create the configured database if it does not exist yet.

    Connects without selecting a database and issues
    ``CREATE DATABASE IF NOT EXISTS``.

    Args:
        mysql_cfg: Connection settings. ``user``, ``password`` and
            ``database`` are required; ``host``, ``port`` and ``charset``
            default to ``127.0.0.1``, ``3306`` and ``utf8mb4``.

    Raises:
        KeyError: If a required key is missing.
        ValueError: If ``charset`` is not a plain identifier.
    """
    charset = mysql_cfg.get("charset", "utf8mb4")
    # DDL cannot be parameterised, so validate/escape everything that is
    # spliced into the statement before opening a connection.
    if not isinstance(charset, str) or not _CHARSET_RE.fullmatch(charset):
        raise ValueError(f"非法的字符集名称: {charset!r}")
    # A literal backtick inside a backtick-quoted identifier must be doubled.
    database = str(mysql_cfg["database"]).replace("`", "``")

    conn = pymysql.connect(
        host=mysql_cfg.get("host", "127.0.0.1"),
        port=int(mysql_cfg.get("port", 3306)),
        user=mysql_cfg["user"],
        password=mysql_cfg["password"],
        charset=charset,
        autocommit=True,
    )
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                f"CREATE DATABASE IF NOT EXISTS `{database}` DEFAULT CHARACTER SET {charset}"
            )
    finally:
        conn.close()


def ensure_seed_file() -> None:
    """Write a small default seed file when ``SPAM_SEED_PATH`` is missing.

    An existing file is left untouched.
    """
    if SPAM_SEED_PATH.exists():
        return
    SPAM_SEED_PATH.parent.mkdir(parents=True, exist_ok=True)
    defaults = [
        {"text": "点击链接领取1000元现金红包,先到先得", "label": "spam"},
        {"text": "今晚项目例会改到19点,记得带周报", "label": "ham"},
        {"text": "恭喜你成为幸运用户,马上回复领取大奖", "label": "spam"},
        {"text": "明天上午10点客户演示,请提前准备材料", "label": "ham"},
    ]
    with SPAM_SEED_PATH.open("w", encoding="utf-8") as file:
        json.dump(defaults, file, ensure_ascii=False, indent=2)


def _normalize_seed_row(item: object) -> Optional[tuple[str, str]]:
    """Return ``(text, label)`` for a usable seed row, or ``None`` to skip it.

    Text has its whitespace runs collapsed to single spaces; the label is
    stripped and lower-cased. Rows that are not dicts, carry non-string
    fields, have text shorter than two characters, or use a label other
    than ``spam``/``ham`` are rejected.
    """
    if not isinstance(item, dict):
        return None
    raw_text = item.get("text") or ""
    raw_label = item.get("label") or ""
    if not isinstance(raw_text, str) or not isinstance(raw_label, str):
        return None
    text = " ".join(raw_text.split())
    label = raw_label.strip().lower()
    if len(text) < 2 or label not in _VALID_LABELS:
        return None
    return text, label


def seed_samples() -> tuple[int, int]:
    """Upsert the samples from the seed file into ``SpamTrainingSample``.

    Existing rows (matched by normalised text) get their label refreshed and
    are re-activated; their ``source`` is only filled in when empty. Unknown
    texts are inserted with ``source="seed"``. Malformed rows are skipped so
    that one bad entry does not abort the whole run. Everything is committed
    in a single transaction at the end.

    Returns:
        A ``(created, updated)`` pair of row counts.

    Raises:
        ValueError: If the seed file's top-level JSON value is not a list.
    """
    ensure_seed_file()
    with SPAM_SEED_PATH.open("r", encoding="utf-8-sig") as file:
        rows = json.load(file)
    if not isinstance(rows, list):
        raise ValueError(f"种子文件格式错误,应为 JSON 数组: {SPAM_SEED_PATH}")

    created = 0
    updated = 0
    for item in rows:
        normalized = _normalize_seed_row(item)
        if normalized is None:
            continue
        text, label = normalized

        sample = SpamTrainingSample.query.filter_by(text=text).first()
        if sample:
            sample.label = label
            sample.is_active = True
            sample.source = sample.source or "seed"
            updated += 1
            continue

        sample = SpamTrainingSample(text=text, label=label, source="seed", is_active=True)
        db.session.add(sample)
        created += 1

    db.session.commit()
    return created, updated


def ensure_detection_config(mysql_cfg: dict) -> float:
    """Return the stored spam threshold, creating the config row if absent.

    When no ``DetectionConfig`` row exists, the initial value is read from
    ``mysql_cfg["detection_init"]["spam_threshold"]``. Missing, unparsable
    or NaN values fall back to 0.75, and the result is clamped to
    [0.01, 0.99] before being persisted.

    Args:
        mysql_cfg: Parsed configuration; a non-dict value is treated as empty.

    Returns:
        The threshold that is now stored in the database.
    """
    cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if cfg:
        return float(cfg.spam_threshold)

    init_cfg = mysql_cfg.get("detection_init") if isinstance(mysql_cfg, dict) else None
    if not isinstance(init_cfg, dict):
        init_cfg = {}

    try:
        threshold = float(init_cfg.get("spam_threshold", _DEFAULT_SPAM_THRESHOLD))
    except (TypeError, ValueError, OverflowError):
        threshold = _DEFAULT_SPAM_THRESHOLD
    # NaN compares false against everything, so min/max would let it through
    # unclamped and it would be persisted as the threshold.
    if math.isnan(threshold):
        threshold = _DEFAULT_SPAM_THRESHOLD
    threshold = min(max(threshold, _MIN_SPAM_THRESHOLD), _MAX_SPAM_THRESHOLD)

    cfg = DetectionConfig(spam_threshold=threshold)
    db.session.add(cfg)
    db.session.commit()
    return threshold


def init_admin(mysql_cfg: dict) -> str:
    """Create the default administrator account unless it already exists.

    Settings come from ``mysql_cfg["admin_init"]``; a missing or non-dict
    value (e.g. JSON ``null``) is treated as empty, so the defaults apply.

    Returns:
        A human-readable status message describing what happened.
    """
    admin_cfg = mysql_cfg.get("admin_init")
    if not isinstance(admin_cfg, dict):
        admin_cfg = {}

    if not admin_cfg.get("create_default_admin", True):
        return "默认管理员创建已关闭"

    username = admin_cfg.get("username", "admin")
    # NOTE(security): this fallback password is hard-coded and therefore
    # publicly known; deployments should set "password" in admin_init and
    # rotate it after first login.
    password = admin_cfg.get("password", "Admin@123456")
    nickname = admin_cfg.get("nickname", "系统管理员")

    admin = User.query.filter_by(username=username).first()
    if admin:
        return f"管理员已存在: {username}"

    admin = User(username=username, nickname=nickname, is_admin=True)
    admin.set_password(password)
    db.session.add(admin)
    db.session.commit()
    return f"管理员已创建: {username}"


def train_initial_model(model_path: str) -> dict:
    """Train the Naive Bayes classifier on all active samples.

    Active ``SpamTrainingSample`` rows are loaded in ascending id order and
    handed to ``NaiveBayesSpamClassifier.train`` as ``{"text", "label"}``
    dicts.

    Args:
        model_path: Path passed to the classifier constructor.

    Returns:
        Whatever ``NaiveBayesSpamClassifier.train`` returns.
    """
    rows = (
        SpamTrainingSample.query.filter_by(is_active=True)
        .order_by(SpamTrainingSample.id.asc())
        .all()
    )
    samples = [{"text": row.text, "label": row.label} for row in rows]
    clf = NaiveBayesSpamClassifier(model_path)
    return clf.train(samples)


def main() -> None:
    """Run the full initialisation sequence and print a summary."""
    mysql_cfg = load_mysql_cfg()
    # The database must exist before the app (and its engine) is created.
    create_database(mysql_cfg)

    app = create_app()
    with app.app_context():
        db.create_all()
        created, updated = seed_samples()
        threshold = ensure_detection_config(mysql_cfg)
        admin_msg = init_admin(mysql_cfg)
        model_meta = train_initial_model(app.config["NB_MODEL_PATH"])

    print("数据库初始化完成")
    print(f"- 样本新增: {created}")
    print(f"- 样本更新: {updated}")
    print(f"- 初始阈值: {threshold}")
    print(f"- {admin_msg}")
    print(f"- 模型版本: {model_meta.get('version')}")


if __name__ == "__main__":
    main()