Oversample minority classes
This commit is contained in:
parent
ee039552d2
commit
2656724bb4
@ -1,6 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pickle
|
import pickle
|
||||||
|
from imblearn.over_sampling import SMOTE
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
from sklearn.metrics import classification_report, confusion_matrix
|
from sklearn.metrics import classification_report, confusion_matrix
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
@ -22,7 +23,10 @@ y = np.array(df['sublabel'])
|
|||||||
|
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
|
||||||
|
|
||||||
clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0)
|
oversampler = SMOTE(random_state=777, k_neighbors=5)
|
||||||
|
X_train, y_train = oversampler.fit_resample(X_train, y_train)
|
||||||
|
|
||||||
|
clf = RandomForestClassifier(n_estimators=20, n_jobs=-1, criterion='gini', random_state=0)
|
||||||
clf.fit(X_train, y_train)
|
clf.fit(X_train, y_train)
|
||||||
|
|
||||||
accuracy = clf.score(X_test, y_test)
|
accuracy = clf.score(X_test, y_test)
|
||||||
|
|||||||
@ -13,3 +13,4 @@ scipy==1.6.3
|
|||||||
six==1.16.0
|
six==1.16.0
|
||||||
sklearn==0.0
|
sklearn==0.0
|
||||||
threadpoolctl==2.1.0
|
threadpoolctl==2.1.0
|
||||||
|
imbalanced-learn==0.8.0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user