Oversample minority classes
This commit is contained in:
parent
ee039552d2
commit
2656724bb4
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle
|
||||
from imblearn.over_sampling import SMOTE
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import classification_report, confusion_matrix
|
||||
from sklearn.model_selection import train_test_split
|
||||
@ -22,7 +23,10 @@ y = np.array(df['sublabel'])
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
|
||||
|
||||
clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0)
|
||||
oversampler = SMOTE(random_state=777, k_neighbors=5)
|
||||
X_train, y_train = oversampler.fit_resample(X_train, y_train)
|
||||
|
||||
clf = RandomForestClassifier(n_estimators=20, n_jobs=-1, criterion='gini', random_state=0)
|
||||
clf.fit(X_train, y_train)
|
||||
|
||||
accuracy = clf.score(X_test, y_test)
|
||||
|
||||
@ -13,3 +13,4 @@ scipy==1.6.3
|
||||
six==1.16.0
|
||||
sklearn==0.0
|
||||
threadpoolctl==2.1.0
|
||||
imbalanced-learn==0.8.0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user