153 lines
4.9 KiB
Python
153 lines
4.9 KiB
Python
import json
|
||
from pathlib import Path
|
||
|
||
import pymysql
|
||
|
||
from app import create_app
|
||
from app.extensions import db
|
||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||
from app.models import DetectionConfig, SpamTrainingSample, User
|
||
|
||
|
||
# Directory that contains this script; every other path is resolved from it.
BASE_DIR = Path(__file__).resolve().parent
# JSON file holding the MySQL connection settings and initialisation options.
MYSQL_CONFIG_PATH = BASE_DIR.joinpath("mysqlconfig.json")
# JSON file holding the labelled spam/ham samples used to seed the database.
SPAM_SEED_PATH = BASE_DIR.joinpath("seed", "spam_samples_seed.json")
|
||
|
||
|
||
def load_mysql_cfg() -> dict:
    """Read and parse the JSON document stored at ``MYSQL_CONFIG_PATH``.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark is
    tolerated.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: If ``MYSQL_CONFIG_PATH`` does not exist.
    """
    if MYSQL_CONFIG_PATH.exists():
        raw_text = MYSQL_CONFIG_PATH.read_text(encoding="utf-8-sig")
        return json.loads(raw_text)
    raise FileNotFoundError(f"未找到配置文件: {MYSQL_CONFIG_PATH}")
|
||
|
||
|
||
def create_database(mysql_cfg: dict) -> None:
    """Create the configured MySQL database if it does not exist yet.

    Opens a connection without selecting a schema and issues
    ``CREATE DATABASE IF NOT EXISTS`` with the configured default charset.
    The connection is always closed, even when the statement fails.

    Args:
        mysql_cfg: Connection settings. ``user``, ``password`` and
            ``database`` are required; ``host``, ``port`` and ``charset``
            fall back to ``127.0.0.1``, ``3306`` and ``utf8mb4``.

    Raises:
        KeyError: If ``database``, ``user`` or ``password`` is missing.
        ValueError: If ``charset`` is not a plain charset name.
    """
    charset = mysql_cfg.get("charset", "utf8mb4")
    # The charset is spliced into the SQL text as a bare word, so only accept
    # plain ASCII letters, digits and underscores.
    if not (
        isinstance(charset, str)
        and charset.isascii()
        and charset.replace("_", "").isalnum()
    ):
        raise ValueError(f"invalid MySQL charset name: {charset!r}")

    # Identifiers cannot be bound as query parameters; inside a
    # backtick-quoted identifier a literal backtick is written by doubling it.
    quoted_database = str(mysql_cfg["database"]).replace("`", "``")

    conn = pymysql.connect(
        host=mysql_cfg.get("host", "127.0.0.1"),
        port=int(mysql_cfg.get("port", 3306)),
        user=mysql_cfg["user"],
        password=mysql_cfg["password"],
        charset=charset,
        autocommit=True,
    )
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                f"CREATE DATABASE IF NOT EXISTS `{quoted_database}` "
                f"DEFAULT CHARACTER SET {charset}"
            )
    finally:
        conn.close()
|
||
|
||
|
||
def ensure_seed_file() -> None:
    """Write a default seed file to ``SPAM_SEED_PATH`` when none exists.

    An existing file is left untouched. Otherwise the parent directory is
    created as needed and four built-in samples (two ``spam``, two ``ham``)
    are written as indented, non-ASCII-escaped UTF-8 JSON.
    """
    if not SPAM_SEED_PATH.exists():
        SPAM_SEED_PATH.parent.mkdir(parents=True, exist_ok=True)
        builtin_pairs = (
            ("点击链接领取1000元现金红包,先到先得", "spam"),
            ("今晚项目例会改到19点,记得带周报", "ham"),
            ("恭喜你成为幸运用户,马上回复领取大奖", "spam"),
            ("明天上午10点客户演示,请提前准备材料", "ham"),
        )
        default_samples = [
            {"text": sample_text, "label": sample_label}
            for sample_text, sample_label in builtin_pairs
        ]
        payload = json.dumps(default_samples, ensure_ascii=False, indent=2)
        SPAM_SEED_PATH.write_text(payload, encoding="utf-8")
|
||
|
||
|
||
def seed_samples() -> tuple[int, int]:
    """Load the seed file into the ``SpamTrainingSample`` table.

    Each entry's text is whitespace-normalised (runs collapsed to single
    spaces, ends trimmed) and its label lower-cased. Entries whose text is
    shorter than two characters or whose label is not ``spam``/``ham`` are
    skipped. A row with the same text is relabelled and reactivated;
    otherwise a new row with source ``seed`` is added. All changes are
    committed once at the end.

    Returns:
        A ``(created, updated)`` pair of row counts.
    """
    ensure_seed_file()
    seed_text = SPAM_SEED_PATH.read_text(encoding="utf-8-sig")
    entries = json.loads(seed_text)

    accepted_labels = {"spam", "ham"}
    created = updated = 0

    for entry in entries:
        raw_text = entry.get("text") or ""
        text = " ".join(raw_text.strip().split())
        raw_label = entry.get("label") or ""
        label = raw_label.strip().lower()
        if not (len(text) >= 2 and label in accepted_labels):
            continue

        existing = SpamTrainingSample.query.filter_by(text=text).first()
        if not existing:
            db.session.add(
                SpamTrainingSample(
                    text=text, label=label, source="seed", is_active=True
                )
            )
            created += 1
        else:
            existing.label = label
            existing.is_active = True
            # Keep a previously recorded source; only fill it in when empty.
            existing.source = existing.source or "seed"
            updated += 1

    db.session.commit()
    return created, updated
|
||
|
||
|
||
def ensure_detection_config(mysql_cfg: dict) -> float:
    """Return the stored spam threshold, creating the config row if missing.

    When a ``DetectionConfig`` row already exists, the threshold of the row
    with the lowest id is returned unchanged. Otherwise the initial value is
    read from ``mysql_cfg["detection_init"]["spam_threshold"]``, falling
    back to 0.75 when the setting is absent, malformed or NaN. The value is
    clamped to ``[0.01, 0.99]``, persisted and returned.

    Args:
        mysql_cfg: Parsed configuration; a non-dict value is treated as
            having no ``detection_init`` section.

    Returns:
        The threshold now stored in the database.
    """
    import math

    default_threshold = 0.75

    existing = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if existing:
        return float(existing.spam_threshold)

    init_cfg = mysql_cfg.get("detection_init") if isinstance(mysql_cfg, dict) else None
    if not isinstance(init_cfg, dict):
        # e.g. "detection_init": null in the config file.
        init_cfg = {}

    try:
        threshold = float(init_cfg.get("spam_threshold", default_threshold))
    except (TypeError, ValueError, OverflowError):
        threshold = default_threshold

    # NaN compares false against everything, so min()/max() would pass it
    # through the clamp below and it would be stored as the threshold.
    if math.isnan(threshold):
        threshold = default_threshold

    threshold = min(max(threshold, 0.01), 0.99)
    db.session.add(DetectionConfig(spam_threshold=threshold))
    db.session.commit()
    return threshold
|
||
|
||
|
||
def init_admin(mysql_cfg: dict) -> str:
    """Create the default administrator account unless disabled or present.

    Settings are read from ``mysql_cfg["admin_init"]``. Nothing is created
    when ``create_default_admin`` is falsy or a user with the configured
    username already exists.

    Args:
        mysql_cfg: Parsed configuration containing an optional
            ``admin_init`` section.

    Returns:
        A human-readable status message describing what happened.
    """
    settings = mysql_cfg.get("admin_init", {})
    creation_enabled = settings.get("create_default_admin", True)
    if not creation_enabled:
        return "默认管理员创建已关闭"

    username = settings.get("username", "admin")
    # NOTE(review): a built-in default password is used when the config omits
    # one; consider requiring an explicit value.
    password = settings.get("password", "Admin@123456")
    nickname = settings.get("nickname", "系统管理员")

    if User.query.filter_by(username=username).first():
        return f"管理员已存在: {username}"

    new_admin = User(username=username, nickname=nickname, is_admin=True)
    new_admin.set_password(password)
    db.session.add(new_admin)
    db.session.commit()
    return f"管理员已创建: {username}"
|
||
|
||
|
||
def train_initial_model(model_path: str) -> dict:
    """Train a ``NaiveBayesSpamClassifier`` on all active training samples.

    Active ``SpamTrainingSample`` rows are fetched in ascending id order and
    passed to the classifier as ``{"text", "label"}`` dicts.

    Args:
        model_path: Passed to the ``NaiveBayesSpamClassifier`` constructor.

    Returns:
        Whatever the classifier's ``train`` method returns.
    """
    active_rows = SpamTrainingSample.query.filter_by(is_active=True)
    ordered_rows = active_rows.order_by(SpamTrainingSample.id.asc()).all()
    training_set = [dict(text=record.text, label=record.label) for record in ordered_rows]
    classifier = NaiveBayesSpamClassifier(model_path)
    return classifier.train(training_set)
|
||
|
||
|
||
def main():
    """Bootstrap the database, seed data, admin account and initial model.

    Steps, in order: load the MySQL config, create the database, create all
    tables, seed the training samples, ensure a detection config row exists,
    create the default admin, train the initial model, then print a summary.
    """
    cfg = load_mysql_cfg()
    create_database(cfg)

    flask_app = create_app()
    with flask_app.app_context():
        # Everything that touches the ORM needs the application context.
        db.create_all()
        seed_counts = seed_samples()
        threshold = ensure_detection_config(cfg)
        admin_msg = init_admin(cfg)
        model_meta = train_initial_model(flask_app.config["NB_MODEL_PATH"])

    created, updated = seed_counts
    print("数据库初始化完成")
    print(f"- 样本新增: {created}")
    print(f"- 样本更新: {updated}")
    print(f"- 初始阈值: {threshold}")
    print(f"- {admin_msg}")
    print(f"- 模型版本: {model_meta.get('version')}")


if __name__ == "__main__":
    main()
|