Save model
This commit is contained in:
parent
72554590fb
commit
bd9d3b6932
@ -1,7 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import pickle
|
||||||
import seaborn as sn
|
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
from sklearn.metrics import classification_report, confusion_matrix
|
from sklearn.metrics import classification_report, confusion_matrix
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
@ -21,9 +20,9 @@ df.drop(['flowStartMilliseconds'], 1, inplace=True)
|
|||||||
X = np.array(df.drop(columns=['sublabel']))
|
X = np.array(df.drop(columns=['sublabel']))
|
||||||
y = np.array(df['sublabel'])
|
y = np.array(df['sublabel'])
|
||||||
|
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
|
||||||
|
|
||||||
clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0, class_weight="balanced")
|
clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, criterion='gini', random_state=0)
|
||||||
clf.fit(X_train, y_train)
|
clf.fit(X_train, y_train)
|
||||||
|
|
||||||
accuracy = clf.score(X_test, y_test)
|
accuracy = clf.score(X_test, y_test)
|
||||||
@ -32,19 +31,11 @@ y_pred_train = clf.predict(X_train)
|
|||||||
y_pred_test = clf.predict(X_test)
|
y_pred_test = clf.predict(X_test)
|
||||||
print("\n *************** TRAINING ****************")
|
print("\n *************** TRAINING ****************")
|
||||||
cm_train = confusion_matrix(y_train, y_pred_train)
|
cm_train = confusion_matrix(y_train, y_pred_train)
|
||||||
plt.figure(figsize=(10, 7))
|
print(cm_train)
|
||||||
sn.heatmap(cm_train, annot=True)
|
|
||||||
plt.xlabel('Truth')
|
|
||||||
plt.ylabel('Predicted')
|
|
||||||
plt.show()
|
|
||||||
print(classification_report(y_train, y_pred_train))
|
print(classification_report(y_train, y_pred_train))
|
||||||
print("\n ************** VALIDATION ***************")
|
print("\n ************** VALIDATION ***************")
|
||||||
cm_test = confusion_matrix(y_test, y_pred_test)
|
cm_test = confusion_matrix(y_test, y_pred_test)
|
||||||
plt.figure(figsize=(10, 7))
|
print(cm_test)
|
||||||
sn.heatmap(cm_test, annot=True)
|
|
||||||
plt.xlabel('Truth')
|
|
||||||
plt.ylabel('Predicted')
|
|
||||||
plt.show()
|
|
||||||
print(classification_report(y_test, y_pred_test))
|
print(classification_report(y_test, y_pred_test))
|
||||||
|
|
||||||
example_measure = np.array([ip_to_bin('2.1.1.1'), ip_to_bin('2.1.1.2'), 0, 0, 1])
|
pickle.dump(clf, open('network_traffic_classifier.sav', 'wb'))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user