Patch summary (commentary placed ahead of the first file header):
  * competition/random_forest.py: import SMOTE from imblearn and call
    fit_resample on X_train / y_train after train_test_split, so only the
    training split is resampled and X_test / y_test are left as split;
    n_estimators of the RandomForestClassifier changes from 50 to 20.
  * competition/requirements.txt: add the pin imbalanced-learn==0.8.0.
  NOTE(review): the train_test_split call shown in the hunk passes no
  random_state, so the split itself is unseeded even though SMOTE and the
  classifier are seeded.

diff --git a/competition/random_forest.py b/competition/random_forest.py
index 7e6e5ed..ed51684 100644
--- a/competition/random_forest.py
+++ b/competition/random_forest.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pandas as pd
 import pickle
+from imblearn.over_sampling import SMOTE
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.model_selection import train_test_split
@@ -22,7 +23,10 @@ y = np.array(df['sublabel'])
 
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
 
-clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0)
+oversampler = SMOTE(random_state=777, k_neighbors=5)
+X_train, y_train = oversampler.fit_resample(X_train, y_train)
+
+clf = RandomForestClassifier(n_estimators=20, n_jobs=-1, criterion='gini', random_state=0)
 clf.fit(X_train, y_train)
 
 accuracy = clf.score(X_test, y_test)
diff --git a/competition/requirements.txt b/competition/requirements.txt
index 352f10a..7062b5c 100644
--- a/competition/requirements.txt
+++ b/competition/requirements.txt
@@ -13,3 +13,4 @@ scipy==1.6.3
 six==1.16.0
 sklearn==0.0
 threadpoolctl==2.1.0
+imbalanced-learn==0.8.0