From ee039552d281d926821e9e5ee7cca8249820b14c Mon Sep 17 00:00:00 2001 From: Tobias Eidelpes Date: Sat, 5 Jun 2021 18:51:24 +0200 Subject: [PATCH] Stratify test split (+0.25 points) --- competition/random_forest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/competition/random_forest.py b/competition/random_forest.py index 0540f0a..7e6e5ed 100644 --- a/competition/random_forest.py +++ b/competition/random_forest.py @@ -20,7 +20,7 @@ df.drop(['flowStartMilliseconds'], 1, inplace=True) X = np.array(df.drop(columns=['sublabel'])) y = np.array(df['sublabel']) -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y) clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0) clf.fit(X_train, y_train)