This commit is contained in:
刘正航
2026-04-21 22:45:19 +08:00
commit b5237f9038
159 changed files with 7769 additions and 0 deletions

152
backend/init_db.py Normal file
View File

@@ -0,0 +1,152 @@
import json
from pathlib import Path
import pymysql
from app import create_app
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamTrainingSample, User
BASE_DIR = Path(__file__).resolve().parent
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"
def load_mysql_cfg() -> dict:
if not MYSQL_CONFIG_PATH.exists():
raise FileNotFoundError(f"未找到配置文件: {MYSQL_CONFIG_PATH}")
with MYSQL_CONFIG_PATH.open("r", encoding="utf-8-sig") as file:
return json.load(file)
def create_database(mysql_cfg: dict) -> None:
conn = pymysql.connect(
host=mysql_cfg.get("host", "127.0.0.1"),
port=int(mysql_cfg.get("port", 3306)),
user=mysql_cfg["user"],
password=mysql_cfg["password"],
charset=mysql_cfg.get("charset", "utf8mb4"),
autocommit=True,
)
try:
with conn.cursor() as cursor:
cursor.execute(
f"CREATE DATABASE IF NOT EXISTS `{mysql_cfg['database']}` DEFAULT CHARACTER SET {mysql_cfg.get('charset', 'utf8mb4')}"
)
finally:
conn.close()
def ensure_seed_file() -> None:
if SPAM_SEED_PATH.exists():
return
SPAM_SEED_PATH.parent.mkdir(parents=True, exist_ok=True)
defaults = [
{"text": "点击链接领取1000元现金红包先到先得", "label": "spam"},
{"text": "今晚项目例会改到19点记得带周报", "label": "ham"},
{"text": "恭喜你成为幸运用户,马上回复领取大奖", "label": "spam"},
{"text": "明天上午10点客户演示请提前准备材料", "label": "ham"},
]
with SPAM_SEED_PATH.open("w", encoding="utf-8") as file:
json.dump(defaults, file, ensure_ascii=False, indent=2)
def seed_samples() -> tuple[int, int]:
ensure_seed_file()
with SPAM_SEED_PATH.open("r", encoding="utf-8-sig") as file:
rows = json.load(file)
created = 0
updated = 0
for item in rows:
text = " ".join((item.get("text") or "").strip().split())
label = (item.get("label") or "").strip().lower()
if len(text) < 2 or label not in {"spam", "ham"}:
continue
sample = SpamTrainingSample.query.filter_by(text=text).first()
if sample:
sample.label = label
sample.is_active = True
sample.source = sample.source or "seed"
updated += 1
continue
sample = SpamTrainingSample(text=text, label=label, source="seed", is_active=True)
db.session.add(sample)
created += 1
db.session.commit()
return created, updated
def ensure_detection_config(mysql_cfg: dict) -> float:
cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
if cfg:
return float(cfg.spam_threshold)
init_cfg = mysql_cfg.get("detection_init", {}) if isinstance(mysql_cfg, dict) else {}
try:
threshold = float(init_cfg.get("spam_threshold", 0.75))
except Exception:
threshold = 0.75
threshold = min(max(threshold, 0.01), 0.99)
cfg = DetectionConfig(spam_threshold=threshold)
db.session.add(cfg)
db.session.commit()
return threshold
def init_admin(mysql_cfg: dict) -> str:
admin_cfg = mysql_cfg.get("admin_init", {})
if not admin_cfg.get("create_default_admin", True):
return "默认管理员创建已关闭"
username = admin_cfg.get("username", "admin")
password = admin_cfg.get("password", "Admin@123456")
nickname = admin_cfg.get("nickname", "系统管理员")
admin = User.query.filter_by(username=username).first()
if admin:
return f"管理员已存在: {username}"
admin = User(username=username, nickname=nickname, is_admin=True)
admin.set_password(password)
db.session.add(admin)
db.session.commit()
return f"管理员已创建: {username}"
def train_initial_model(model_path: str) -> dict:
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
samples = [{"text": row.text, "label": row.label} for row in rows]
clf = NaiveBayesSpamClassifier(model_path)
return clf.train(samples)
def main():
mysql_cfg = load_mysql_cfg()
create_database(mysql_cfg)
app = create_app()
with app.app_context():
db.create_all()
created, updated = seed_samples()
threshold = ensure_detection_config(mysql_cfg)
admin_msg = init_admin(mysql_cfg)
model_meta = train_initial_model(app.config["NB_MODEL_PATH"])
print("数据库初始化完成")
print(f"- 样本新增: {created}")
print(f"- 样本更新: {updated}")
print(f"- 初始阈值: {threshold}")
print(f"- {admin_msg}")
print(f"- 模型版本: {model_meta.get('version')}")
if __name__ == "__main__":
main()