Save model

This commit is contained in:
Tobias Eidelpes 2021-06-05 18:43:22 +02:00
parent 72554590fb
commit bd9d3b6932

View File

@ -1,7 +1,6 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import pickle
import seaborn as sn
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
@ -21,9 +20,9 @@ df.drop(['flowStartMilliseconds'], 1, inplace=True)
X = np.array(df.drop(columns=['sublabel'])) X = np.array(df.drop(columns=['sublabel']))
y = np.array(df['sublabel']) y = np.array(df['sublabel'])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0, class_weight="balanced") clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0)
clf.fit(X_train, y_train) clf.fit(X_train, y_train)
accuracy = clf.score(X_test, y_test) accuracy = clf.score(X_test, y_test)
@ -32,19 +31,11 @@ y_pred_train = clf.predict(X_train)
y_pred_test = clf.predict(X_test) y_pred_test = clf.predict(X_test)
print("\n *************** TRAINING ****************") print("\n *************** TRAINING ****************")
cm_train = confusion_matrix(y_train, y_pred_train) cm_train = confusion_matrix(y_train, y_pred_train)
plt.figure(figsize=(10, 7)) print(cm_train)
sn.heatmap(cm_train, annot=True)
plt.xlabel('Truth')
plt.ylabel('Predicted')
plt.show()
print(classification_report(y_train, y_pred_train)) print(classification_report(y_train, y_pred_train))
print("\n ************** VALIDATION ***************") print("\n ************** VALIDATION ***************")
cm_test = confusion_matrix(y_test, y_pred_test) cm_test = confusion_matrix(y_test, y_pred_test)
plt.figure(figsize=(10, 7)) print(cm_test)
sn.heatmap(cm_test, annot=True)
plt.xlabel('Truth')
plt.ylabel('Predicted')
plt.show()
print(classification_report(y_test, y_pred_test)) print(classification_report(y_test, y_pred_test))
example_measure = np.array([ip_to_bin('2.1.1.1'), ip_to_bin('2.1.1.2'), 0, 0, 1]) pickle.dump(clf, open('network_traffic_classifier.sav', 'wb'))