Oversample minority classes

This commit is contained in:
Tobias Eidelpes 2021-06-05 21:00:58 +02:00
parent ee039552d2
commit 2656724bb4
2 changed files with 6 additions and 1 deletion

View File

@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import pickle
from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
@@ -22,7 +23,10 @@ y = np.array(df['sublabel'])
# Hold out 20% of the samples as a test split; stratify=y asks the splitter
# to preserve the label distribution of y in both the train and test parts.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
# NOTE(review): this classifier is rebound a few lines below before it is
# ever fitted, so the n_estimators=50 configuration has no effect here.
# Presumably this is the pre-change line of the diff shown without its "-"
# marker — confirm against the actual file.
clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0)
# Oversample with SMOTE using a fixed seed so the resampling is repeatable.
# Only the training split is resampled; X_test / y_test are left untouched,
# so the score below is computed on the original (non-synthetic) samples.
oversampler = SMOTE(random_state=777, k_neighbors=5)
X_train, y_train = oversampler.fit_resample(X_train, y_train)
# Random forest with 20 trees, Gini split criterion, all CPU cores
# (n_jobs=-1) and a fixed seed; this is the instance that actually gets fitted.
clf = RandomForestClassifier(n_estimators=20, n_jobs=-1, criterion='gini', random_state=0)
clf.fit(X_train, y_train)
# Score the fitted model on the held-out test split.
accuracy = clf.score(X_test, y_test)

View File

@@ -13,3 +13,4 @@ scipy==1.6.3
six==1.16.0
sklearn==0.0
threadpoolctl==2.1.0
imbalanced-learn==0.8.0