1
This commit is contained in:
17
backend/train_model.py
Normal file
17
backend/train_model.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from app import create_app
|
||||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||||
from app.models import SpamTrainingSample
|
||||
|
||||
|
||||
def main():
|
||||
app = create_app()
|
||||
with app.app_context():
|
||||
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
|
||||
samples = [{"text": row.text, "label": row.label} for row in rows]
|
||||
clf = NaiveBayesSpamClassifier(app.config["NB_MODEL_PATH"])
|
||||
meta = clf.train(samples)
|
||||
print(f"模型训练完成: {meta.get('version')} 样本数={meta.get('sample_count')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user