1
This commit is contained in:
152
backend/init_db.py
Normal file
152
backend/init_db.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pymysql
|
||||
|
||||
from app import create_app
|
||||
from app.extensions import db
|
||||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||||
from app.models import DetectionConfig, SpamTrainingSample, User
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
|
||||
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"
|
||||
|
||||
|
||||
def load_mysql_cfg() -> dict:
|
||||
if not MYSQL_CONFIG_PATH.exists():
|
||||
raise FileNotFoundError(f"未找到配置文件: {MYSQL_CONFIG_PATH}")
|
||||
with MYSQL_CONFIG_PATH.open("r", encoding="utf-8-sig") as file:
|
||||
return json.load(file)
|
||||
|
||||
|
||||
def create_database(mysql_cfg: dict) -> None:
|
||||
conn = pymysql.connect(
|
||||
host=mysql_cfg.get("host", "127.0.0.1"),
|
||||
port=int(mysql_cfg.get("port", 3306)),
|
||||
user=mysql_cfg["user"],
|
||||
password=mysql_cfg["password"],
|
||||
charset=mysql_cfg.get("charset", "utf8mb4"),
|
||||
autocommit=True,
|
||||
)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f"CREATE DATABASE IF NOT EXISTS `{mysql_cfg['database']}` DEFAULT CHARACTER SET {mysql_cfg.get('charset', 'utf8mb4')}"
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def ensure_seed_file() -> None:
|
||||
if SPAM_SEED_PATH.exists():
|
||||
return
|
||||
SPAM_SEED_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
defaults = [
|
||||
{"text": "点击链接领取1000元现金红包,先到先得", "label": "spam"},
|
||||
{"text": "今晚项目例会改到19点,记得带周报", "label": "ham"},
|
||||
{"text": "恭喜你成为幸运用户,马上回复领取大奖", "label": "spam"},
|
||||
{"text": "明天上午10点客户演示,请提前准备材料", "label": "ham"},
|
||||
]
|
||||
with SPAM_SEED_PATH.open("w", encoding="utf-8") as file:
|
||||
json.dump(defaults, file, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def seed_samples() -> tuple[int, int]:
|
||||
ensure_seed_file()
|
||||
with SPAM_SEED_PATH.open("r", encoding="utf-8-sig") as file:
|
||||
rows = json.load(file)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
|
||||
for item in rows:
|
||||
text = " ".join((item.get("text") or "").strip().split())
|
||||
label = (item.get("label") or "").strip().lower()
|
||||
if len(text) < 2 or label not in {"spam", "ham"}:
|
||||
continue
|
||||
|
||||
sample = SpamTrainingSample.query.filter_by(text=text).first()
|
||||
if sample:
|
||||
sample.label = label
|
||||
sample.is_active = True
|
||||
sample.source = sample.source or "seed"
|
||||
updated += 1
|
||||
continue
|
||||
|
||||
sample = SpamTrainingSample(text=text, label=label, source="seed", is_active=True)
|
||||
db.session.add(sample)
|
||||
created += 1
|
||||
|
||||
db.session.commit()
|
||||
return created, updated
|
||||
|
||||
|
||||
def ensure_detection_config(mysql_cfg: dict) -> float:
|
||||
cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
|
||||
if cfg:
|
||||
return float(cfg.spam_threshold)
|
||||
|
||||
init_cfg = mysql_cfg.get("detection_init", {}) if isinstance(mysql_cfg, dict) else {}
|
||||
try:
|
||||
threshold = float(init_cfg.get("spam_threshold", 0.75))
|
||||
except Exception:
|
||||
threshold = 0.75
|
||||
|
||||
threshold = min(max(threshold, 0.01), 0.99)
|
||||
cfg = DetectionConfig(spam_threshold=threshold)
|
||||
db.session.add(cfg)
|
||||
db.session.commit()
|
||||
return threshold
|
||||
|
||||
|
||||
def init_admin(mysql_cfg: dict) -> str:
|
||||
admin_cfg = mysql_cfg.get("admin_init", {})
|
||||
if not admin_cfg.get("create_default_admin", True):
|
||||
return "默认管理员创建已关闭"
|
||||
|
||||
username = admin_cfg.get("username", "admin")
|
||||
password = admin_cfg.get("password", "Admin@123456")
|
||||
nickname = admin_cfg.get("nickname", "系统管理员")
|
||||
|
||||
admin = User.query.filter_by(username=username).first()
|
||||
if admin:
|
||||
return f"管理员已存在: {username}"
|
||||
|
||||
admin = User(username=username, nickname=nickname, is_admin=True)
|
||||
admin.set_password(password)
|
||||
db.session.add(admin)
|
||||
db.session.commit()
|
||||
return f"管理员已创建: {username}"
|
||||
|
||||
|
||||
def train_initial_model(model_path: str) -> dict:
|
||||
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
|
||||
samples = [{"text": row.text, "label": row.label} for row in rows]
|
||||
clf = NaiveBayesSpamClassifier(model_path)
|
||||
return clf.train(samples)
|
||||
|
||||
|
||||
def main():
|
||||
mysql_cfg = load_mysql_cfg()
|
||||
create_database(mysql_cfg)
|
||||
|
||||
app = create_app()
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
created, updated = seed_samples()
|
||||
threshold = ensure_detection_config(mysql_cfg)
|
||||
admin_msg = init_admin(mysql_cfg)
|
||||
model_meta = train_initial_model(app.config["NB_MODEL_PATH"])
|
||||
|
||||
print("数据库初始化完成")
|
||||
print(f"- 样本新增: {created}")
|
||||
print(f"- 样本更新: {updated}")
|
||||
print(f"- 初始阈值: {threshold}")
|
||||
print(f"- {admin_msg}")
|
||||
print(f"- 模型版本: {model_meta.get('version')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user