diff --git a/backend/app/ml/spam_categorizer.py b/backend/app/ml/spam_categorizer.py
new file mode 100644
index 0000000..55b8dc2
--- /dev/null
+++ b/backend/app/ml/spam_categorizer.py
@@ -0,0 +1,80 @@
+"""Spam category labelling.
+
+Refines the binary naive-Bayes verdict (spam/ham): text that was judged to be
+spam gets a finer-grained category via keyword matching.
+Category priority: fraud > harassment > advertisement (ordered by severity).
+"""
+
+# Keywords are matched as plain substrings of the lower-cased text.
+# NOTE(review): some entries are shadowed by shorter ones and can never decide
+# the result on their own, e.g. "百分百中奖" (advertisement) always loses to
+# "中奖" (fraud, checked first) and "将被冻结" is already covered by "被冻结".
+CATEGORY_KEYWORDS = {
+    "fraud": [
+        "中奖", "幸运粉丝", "幸运用户", "银行卡异常", "社保异常", "账号冻结",
+        "解封", "立即验证", "验证码", "欠费停机", "退款待确认", "违章信息",
+        "紧急通知", "账户异常", "风险", "核验", "被冻结", "将被冻结",
+    ],
+    "harassment": [
+        "兼职", "日结", "高薪", "刷单", "赚钱", "外快", "宝妈", "学生都能做",
+        "添加微信", "扫码进群", "进群立刻", "想赚", "零花钱", "在家办公",
+        "无需面试", "火热招募", "秒赚", "招募",
+    ],
+    "advertisement": [
+        "领取", "优惠", "红包", "优惠券", "秒杀", "返现", "补贴", "会员",
+        "特价", "低价", "点击链接", "扫码", "免费领取", "无门槛", "现金券",
+        "盲盒", "百分百中奖", "隐藏优惠券", "内部价", "货到付款", "限时",
+        "最后", "名额", "先到先得",
+    ],
+}
+
+# Display label for each category code; "ham" maps to an empty label.
+CATEGORY_LABELS = {
+    "fraud": "疑似诈骗",
+    "harassment": "疑似骚扰",
+    "advertisement": "疑似广告",
+    "spam": "疑似垃圾",
+    "ham": "",
+}
+
+# Categories are tried in this order; the first one with a keyword hit wins.
+CATEGORY_PRIORITY = ["fraud", "harassment", "advertisement"]
+
+
+def categorize_spam(text: str) -> tuple[str, str]:
+    """Return the fine-grained category of a spam text by keyword matching.
+
+    Args:
+        text: The text to categorize. ``None`` or an empty string is treated
+            as having no keyword hit.
+
+    Returns:
+        tuple[str, str]: ``(category_code, category_label)``
+            - category_code: fraud | harassment | advertisement | spam
+            - category_label: the matching entry of ``CATEGORY_LABELS``
+    """
+    generic = ("spam", CATEGORY_LABELS["spam"])
+    if not text:
+        # Nothing to match; this also avoids AttributeError on None.
+        return generic
+
+    lowered = text.lower()
+    for category in CATEGORY_PRIORITY:
+        keywords = CATEGORY_KEYWORDS.get(category, [])
+        if any(kw.lower() in lowered for kw in keywords):
+            return category, CATEGORY_LABELS[category]
+
+    # No keyword matched: fall back to the generic spam label.
+    return generic
+
+
+def get_category_label(category: str) -> str:
+    """Return the display label for a category code.
+
+    Args:
+        category: Category code.
+
+    Returns:
+        str: The label from ``CATEGORY_LABELS``, or ``""`` for an unknown code.
+    """
+    return CATEGORY_LABELS.get(category, "")
\ No newline at end of file
diff --git a/backend/app/models.py b/backend/app/models.py index 5361142..79fa382 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -79,6 +79,7 @@ class SpamPredictionLog(db.Model): user_id = db.Column(db.Integer,
db.ForeignKey("users.id"), nullable=False, index=True) text = db.Column(db.Text, nullable=False) prediction = db.Column(db.String(16), nullable=False) # spam | ham + category = db.Column(db.String(32), default="") # fraud | harassment | advertisement | spam | 空 spam_probability = db.Column(db.Float, nullable=False) ham_probability = db.Column(db.Float, nullable=False) confidence = db.Column(db.Float, nullable=False) @@ -92,6 +93,7 @@ class SpamPredictionLog(db.Model): "user_id": self.user_id, "text": self.text, "prediction": self.prediction, + "category": self.category or "", "spam_probability": round(float(self.spam_probability), 4), "ham_probability": round(float(self.ham_probability), 4), "confidence": round(float(self.confidence), 4), @@ -130,6 +132,7 @@ class ContentPost(db.Model): status = db.Column(db.String(16), nullable=False, default="published") # published | blocked prediction = db.Column(db.String(16), nullable=False, default="ham") + category = db.Column(db.String(32), default="") # fraud | harassment | advertisement | spam | 空 spam_probability = db.Column(db.Float, nullable=False, default=0) ham_probability = db.Column(db.Float, nullable=False, default=0) confidence = db.Column(db.Float, nullable=False, default=0) @@ -163,6 +166,7 @@ class ContentPost(db.Model): "visibility": self.visibility, "status": self.status, "prediction": self.prediction, + "category": self.category or "", "spam_probability": round(float(self.spam_probability), 4), "ham_probability": round(float(self.ham_probability), 4), "confidence": round(float(self.confidence), 4), diff --git a/backend/app/routes/content_routes.py b/backend/app/routes/content_routes.py index 36ee323..a4f049a 100644 --- a/backend/app/routes/content_routes.py +++ b/backend/app/routes/content_routes.py @@ -5,6 +5,7 @@ from flask_jwt_extended import jwt_required from app.extensions import db from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier +from app.ml.spam_categorizer import categorize_spam, 
get_category_label from app.models import ContentPost, DetectionConfig, SpamPredictionLog, SpamTrainingSample, User from app.utils.auth import current_user from app.utils.response import fail, ok @@ -77,7 +78,7 @@ def _resolve_recipient(payload: dict, visibility: str, current_user_id: int): return recipient, None -def _predict_and_decide(text: str, user_credit: int = 100) -> tuple[dict, float, bool]: +def _predict_and_decide(text: str, user_credit: int = 100) -> tuple[dict, float, bool, str, str]: """根据用户信誉分调整阈值系数。信誉分越高,阈值越高(降低敏感度)""" clf = _ensure_ready() result = clf.predict(text) @@ -92,7 +93,14 @@ def _predict_and_decide(text: str, user_credit: int = 100) -> tuple[dict, float, adjusted_threshold = base_threshold * credit_factor blocked = float(result["spam_probability"]) >= adjusted_threshold - return result, adjusted_threshold, blocked + + # 分类标签 + category = "" + category_label = "" + if blocked: + category, category_label = categorize_spam(result["text"]) + + return result, adjusted_threshold, blocked, category, category_label @content_bp.post("/publish") @@ -113,7 +121,7 @@ def publish_text(): if err: return fail(err, 400) - result, threshold, blocked = _predict_and_decide(text, user.credit_score or 100) + result, threshold, blocked, category, category_label = _predict_and_decide(text, user.credit_score or 100) post = ContentPost( user_id=user.id, @@ -122,6 +130,7 @@ def publish_text(): visibility=visibility, status="blocked" if blocked else "published", prediction=result["prediction"], + category=category, spam_probability=result["spam_probability"], ham_probability=result["ham_probability"], confidence=result["confidence"], @@ -135,6 +144,7 @@ def publish_text(): user_id=user.id, text=result["text"], prediction=result["prediction"], + category=category, spam_probability=result["spam_probability"], ham_probability=result["ham_probability"], confidence=result["confidence"], @@ -153,14 +163,18 @@ def publish_text(): db.session.commit() - feedback = "发布成功" if 
not blocked else "疑似垃圾信息,系统已拦截,可提交申诉" + feedback = "发布成功" if not blocked else f"{category_label or '疑似垃圾信息'},系统已拦截,可提交申诉" return ok( { "publish_allowed": not blocked, "action": "published" if not blocked else "blocked", "feedback": feedback, "post": _serialize_post(post), - "detect": result, + "detect": { + **result, + "category": category, + "category_label": category_label, + }, }, feedback, ) @@ -188,13 +202,14 @@ def edit_post(post_id: int): if err: return fail(err, 400) - result, threshold, blocked = _predict_and_decide(text, user.credit_score or 100) + result, threshold, blocked, category, category_label = _predict_and_decide(text, user.credit_score or 100) post.text = result["text"] post.visibility = visibility post.recipient_user_id = recipient.id if recipient else None post.status = "blocked" if blocked else "published" post.prediction = result["prediction"] + post.category = category post.spam_probability = result["spam_probability"] post.ham_probability = result["ham_probability"] post.confidence = result["confidence"] diff --git a/backend/app/routes/spam_routes.py b/backend/app/routes/spam_routes.py index 9c3c90b..3d62970 100644 --- a/backend/app/routes/spam_routes.py +++ b/backend/app/routes/spam_routes.py @@ -3,6 +3,7 @@ from flask_jwt_extended import jwt_required from app.extensions import db from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier +from app.ml.spam_categorizer import categorize_spam, get_category_label from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample from app.utils.auth import admin_required, current_user from app.utils.response import fail, ok @@ -58,10 +59,17 @@ def predict_one(): threshold = _adjusted_threshold(user.credit_score or 100) blocked = float(result["spam_probability"]) >= threshold + # 分类标签:仅在判定为垃圾时进行细分 + category = "" + category_label = "" + if blocked: + category, category_label = categorize_spam(result["text"]) + row = SpamPredictionLog( user_id=user.id, text=result["text"], 
prediction=result["prediction"], + category=category, spam_probability=result["spam_probability"], ham_probability=result["ham_probability"], confidence=result["confidence"], @@ -71,7 +79,14 @@ def predict_one(): db.session.add(row) db.session.commit() - return ok({**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked}, "识别成功") + return ok({ + **result, + "log_id": row.id, + "threshold": threshold, + "blocked_by_threshold": blocked, + "category": category, + "category_label": category_label, + }, "识别成功") @spam_bp.post("/predict/batch") @@ -98,12 +113,23 @@ def predict_batch(): if len(content) < 2: continue result = clf.predict(content) - result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold + blocked = float(result["spam_probability"]) >= threshold + result["blocked_by_threshold"] = blocked + + # 分类标签 + category = "" + category_label = "" + if blocked: + category, category_label = categorize_spam(result["text"]) + result["category"] = category + result["category_label"] = category_label + rows.append( SpamPredictionLog( user_id=user.id, text=result["text"], prediction=result["prediction"], + category=category, spam_probability=result["spam_probability"], ham_probability=result["ham_probability"], confidence=result["confidence"], diff --git a/backend/sql/add_category_field.sql b/backend/sql/add_category_field.sql new file mode 100644 index 0000000..c3f42e9 --- /dev/null +++ b/backend/sql/add_category_field.sql @@ -0,0 +1,5 @@ +-- 添加 category 字段到 spam_prediction_logs 表 +ALTER TABLE `spam_prediction_logs` ADD COLUMN `category` VARCHAR(32) DEFAULT '' AFTER `prediction`; + +-- 添加 category 字段到 content_posts 表 +ALTER TABLE `content_posts` ADD COLUMN `category` VARCHAR(32) DEFAULT '' AFTER `prediction`; \ No newline at end of file