Commit 8e98bbd

feat(be): improve abuse classifier model performance, implement input batch processing (#111)

* feat: add batch prediction function for classifier model
* feat: update slang prediction endpoint to support batch input

1 parent 1410038 · commit 8e98bbd

File tree: 3 files changed (+45 −11 lines)

apps/classifier/app.py

Lines changed: 6 additions & 4 deletions

@@ -5,7 +5,7 @@
     PredictionRequest,
     PredictionResponse,
 )
-from model import predict
+from model import predict, predict_batch

 app = FastAPI()

@@ -26,6 +26,8 @@ async def improve_reply_predict(data: PredictionRequest):

 @app.post("/slang-predict", response_model=SlangPredictionResponse)
 async def slang_predict(data: SlangPredictionRequest):
-    text = data.input
-    predicted = predict(text, type="slang")
-    return {"predicted": predicted[0], "probability": predicted[1]}
+    text = data.inputs
+    predicted = predict_batch(text, type="slang")
+    return {
+        "predictions": [{"predicted": p[0], "probability": p[1]} for p in predicted]
+    }
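For reference, a minimal sketch of exercising the reworked endpoint through FastAPI's TestClient; the import path and the sample payload are illustrative assumptions, not part of this commit.

# Sketch: calling the batched /slang-predict endpoint. Assumes the module
# above is importable as `app` (e.g. run from within apps/classifier/).
from fastapi.testclient import TestClient

from app import app  # assumed import path

client = TestClient(app)

# The request body now carries a list under "inputs" instead of a single "input".
response = client.post("/slang-predict", json={"inputs": ["X같네", "안녕하세요"]})

# The response wraps one {predicted, probability} object per input, in order.
for item in response.json()["predictions"]:
    print(item["predicted"], item["probability"])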

apps/classifier/model.py

Lines changed: 30 additions & 0 deletions

@@ -49,3 +49,33 @@ def predict(text: str, type: str) -> tuple[str, float]:
     predicted_probability = probabilities[0, predicted_label].item()

     return inv_label_map[predicted_label], predicted_probability
+
+
+def predict_batch(texts: list[str], type: str) -> list[tuple[str, float]]:
+    inputs = tokenizer(
+        texts,
+        return_tensors="pt",
+        truncation=True,
+        padding="max_length",
+        max_length=512,
+    )
+
+    model = models[type]
+    inv_label_map = inv_label_maps[type]
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    logits = outputs.logits
+    probabilities = torch.softmax(logits, dim=-1)
+    predicted_labels = torch.argmax(probabilities, dim=-1)
+    predicted_probabilities = probabilities[
+        torch.arange(probabilities.size(0)), predicted_labels
+    ]
+
+    return [
+        (inv_label_map[label], prob)
+        for label, prob in zip(
+            predicted_labels.tolist(), predicted_probabilities.tolist()
+        )
+    ]
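A self-contained sketch of the same batch-inference pattern for trying it outside the app: the checkpoint name is hypothetical (the commit's real models and label maps live as globals in model.py), and it swaps the commit's fixed padding="max_length" for dynamic per-batch padding, a common efficiency tweak.

# Standalone sketch of batched sequence classification with Hugging Face
# transformers. "your-org/slang-classifier" is a hypothetical checkpoint.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "your-org/slang-classifier"  # hypothetical fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.eval()

texts = ["X같네", "좋은 하루 보내세요"]
# padding=True pads only to the longest sequence in the batch, whereas the
# commit pads every batch to a fixed 512 tokens via padding="max_length".
inputs = tokenizer(
    texts, return_tensors="pt", truncation=True, padding=True, max_length=512
)

with torch.no_grad():
    logits = model(**inputs).logits

probabilities = torch.softmax(logits, dim=-1)
labels = torch.argmax(probabilities, dim=-1)
# Pick each row's probability at its own argmax column.
confidences = probabilities[torch.arange(len(texts)), labels]

for text, label, conf in zip(texts, labels.tolist(), confidences.tolist()):
    print(text, model.config.id2label.get(label, label), round(conf, 3))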

apps/classifier/schemas.py

Lines changed: 9 additions & 7 deletions

@@ -1,27 +1,29 @@
 from pydantic import BaseModel
+from typing import List


 class SlangPredictionRequest(BaseModel):
-    input: str
+    inputs: List[str]

     class Config:
         json_schema_extra = {
             "example": {
-                "input": "X같네",
+                "inputs": ["X같네"],
             }
         }


-class SlangPredictionResponse(BaseModel):
+class SlangPredictionItem(BaseModel):
     predicted: str
     probability: float

+
+class SlangPredictionResponse(BaseModel):
+    predictions: List[SlangPredictionItem]
+
     class Config:
         json_schema_extra = {
-            "example": {
-                "predicted": "욕설",
-                "probability": 0.99,
-            }
+            "example": {"predictions": [{"predicted": "욕설", "probability": 0.99}]}
         }
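A minimal sketch of round-tripping the new request/response shapes, assuming Pydantic v2 (which the json_schema_extra key suggests) and that the module is importable as `schemas` from within apps/classifier/.

# Sketch: validating payloads against the updated schemas (Pydantic v2 assumed).
from schemas import SlangPredictionRequest, SlangPredictionResponse

request = SlangPredictionRequest.model_validate({"inputs": ["X같네"]})

# The response model now nests one item per input under "predictions".
response = SlangPredictionResponse.model_validate(
    {"predictions": [{"predicted": "욕설", "probability": 0.99}]}
)
print(request.inputs, response.predictions[0].predicted)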
