from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Activation
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers.legacy import SGD
from tensorflow.keras.callbacks import Callback
from sklearn.datasets import make_circles
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import io
from mlxtend.plotting import plot_decision_regions
2023-11-30 19:56:51.557177: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
First let's generate some data. We will be using SKLearn's make_circles to create a dataset that isn't linearly separable.
# Build a 2-D toy dataset of two concentric rings — deliberately not
# linearly separable, so a linear classifier must fail on it.
num_points = 300
x, y = make_circles(
    n_samples=num_points,
    shuffle=True,
    noise=0.01,
    factor=0.7,        # inner circle radius relative to the outer one
    random_state=144,  # fixed seed for reproducibility
)
y
array([0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1])
Let's visualize the data. As we can see, it consists of two concentric circles where the color of each point represents which class it belongs to.
# Visualize the dataset: two concentric circles, with point colour
# encoding the class label (0 or 1).
plt.figure(figsize=(5,5))
plt.scatter(x[:, 0], x[:, 1], c=y)
plt.show()
To ensure we aren't overfitting during training, let's split our x and y data into training and testing sets.
To visualize why we want to use a neural network for this problem, let's try to use SKLearn's Logistic Regression to classify the points:
def perform_logistic_regression(x, y):
    """Fit a logistic-regression baseline and visualize its decision boundary.

    Splits (x, y) 80/20, fits sklearn's LogisticRegression, prints training
    and testing accuracy, then overlays the learned linear decision boundary
    on a scatter plot of the full dataset.

    Args:
        x: array of shape (n_samples, 2) — the 2-D point coordinates.
        y: array of shape (n_samples,) — binary class labels.
    """
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=144)
    clf = LogisticRegression(random_state=144).fit(x_train, y_train)
    print('Training Score: ', clf.score(x_train, y_train))
    print('Testing Score: ', clf.score(x_test, y_test))
    plt.figure(figsize=(5,5))
    plt.scatter(x[:, 0], x[:, 1], c=y)
    # The decision boundary satisfies w0*x0 + w1*x1 + b = 0, i.e.
    # x1 = -(w0*x0 + b) / w1. The original code plotted
    # `clf.coef_ * x_test + clf.intercept_`, which broadcasts the 2-D weight
    # vector over the test points and draws meaningless line segments.
    w0, w1 = clf.coef_[0]
    b = clf.intercept_[0]
    x0 = np.linspace(x[:, 0].min(), x[:, 0].max(), 100)
    plt.plot(x0, -(w0 * x0 + b) / w1, linewidth=3)
    plt.show()


perform_logistic_regression(x, y)
Training Score: 0.5083333333333333 Testing Score: 0.4666666666666667
It's apparent that our linear decision boundary won't cut it for this data. Let's try to build a simple neural network that can distinguish between these two classes:
# tanh(x) = 2 * sigmoid(2x) - 1
def build_model1():
    """Build a small MLP (2 -> 8 -> 8 -> 1) with the Keras functional API.

    Two tanh hidden layers supply the non-linearity needed to separate the
    concentric circles; the single sigmoid output unit produces a class
    probability suitable for binary cross-entropy.

    Returns:
        An uncompiled tf.keras Model.
    """
    # Keras expects a shape *tuple*; (2) is just the int 2, so write (2,).
    input_layer = Input(shape=(2,))
    x = Dense(8, activation='tanh')(input_layer)
    x = Dense(8, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)
    return Model(input_layer, x)
def build_model2():
    """Build the same 2 -> 8 -> 8 -> 1 MLP as build_model1, via the Sequential API."""
    layers = [
        Dense(units=8, input_dim=2, activation='tanh'),
        Dense(units=8, activation='tanh'),
        Dense(units=1, activation='sigmoid'),
    ]
    model = Sequential()
    for layer in layers:
        model.add(layer)
    return model
# Instantiate the functional-API model and print its layer/parameter
# summary (105 trainable parameters — see output below).
model = build_model1()
model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 2)] 0 dense (Dense) (None, 8) 24 dense_1 (Dense) (None, 8) 72 dense_2 (Dense) (None, 1) 9 ================================================================= Total params: 105 Trainable params: 105 Non-trainable params: 0 _________________________________________________________________
Now we need to declare our optimizer and compile our model with our chosen loss, optimizer, and metrics.
# SGD with momentum; `decay` shrinks the learning rate each update
# (accepted by the legacy optimizer imported above).
sgd = SGD(learning_rate=0.05, decay=1e-6, momentum=0.9)
# Binary cross-entropy matches the single sigmoid output unit.
model.compile(loss='binary_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])
Before we train our model, let's use mlxtend to visualize the decision boundary. We can see that it initially cannot distinguish between the two classes.
# Plot the untrained model's decision boundary. mlxtend evaluates
# model.predict over a dense grid of points, which is what produces the
# long progress bar in the output below.
plt.figure(figsize=(5,5))
plt.scatter(x[:, 0], x[:, 1])
plot_decision_regions(x, y, clf=model, legend=2)
plt.show()
7813/7813 [==============================] - 9s 1ms/step
For demonstration purposes, we create a callback to recreate the above plot every 10 epochs so we can visualize the progression of our decision boundary as we train our model.
class ImgCallback(Callback):
    """Keras callback that plots the decision boundary every 10 epochs.

    Note: still reads the dataset from the module-level globals `x` and `y`,
    matching the notebook's demonstration style.
    """

    def on_epoch_begin(self, epoch, logs=None):
        """Render the current decision boundary at the start of every 10th epoch."""
        if epoch % 10 == 0:
            fig = plt.figure(figsize=(5,5))
            plt.scatter(x[:, 0], x[:, 1])
            # Use self.model (set by Keras when fit() starts) rather than the
            # global `model` variable, so the callback works with any model
            # it is attached to.
            plot_decision_regions(x, y, clf=self.model, legend=2)
            plt.show()
            plt.close(fig)
Now let's train our model:
# Create callback object
img_callback = ImgCallback()
# Perform model fitting. validation_split=0.2 holds out the last 20% of
# (x, y) for validation; the callback redraws the decision boundary every
# 10 epochs during training.
history = model.fit(x,
                    y,
                    epochs=100,
                    batch_size=128,
                    validation_split=0.2,
                    callbacks=[img_callback])
# Printing our last decision boundary plot
fig = plt.figure(figsize=(5,5))
plt.scatter(x[:, 0], x[:, 1])
plot_decision_regions(x, y, clf=model, legend=2)
plt.show()
plt.close(fig)
7813/7813 [==============================] - 8s 1ms/step
Epoch 1/100 2/2 [==============================] - 1s 239ms/step - loss: 0.7216 - accuracy: 0.5083 - val_loss: 0.7423 - val_accuracy: 0.4333 Epoch 2/100 2/2 [==============================] - 0s 32ms/step - loss: 0.7157 - accuracy: 0.5125 - val_loss: 0.7355 - val_accuracy: 0.4167 Epoch 3/100 2/2 [==============================] - 0s 36ms/step - loss: 0.7066 - accuracy: 0.5250 - val_loss: 0.7270 - val_accuracy: 0.4167 Epoch 4/100 2/2 [==============================] - 0s 33ms/step - loss: 0.6983 - accuracy: 0.5417 - val_loss: 0.7190 - val_accuracy: 0.4333 Epoch 5/100 2/2 [==============================] - 0s 31ms/step - loss: 0.6909 - accuracy: 0.5167 - val_loss: 0.7141 - val_accuracy: 0.4000 Epoch 6/100 2/2 [==============================] - 0s 40ms/step - loss: 0.6875 - accuracy: 0.5583 - val_loss: 0.7089 - val_accuracy: 0.4167 Epoch 7/100 2/2 [==============================] - 0s 32ms/step - loss: 0.6864 - accuracy: 0.5208 - val_loss: 0.7055 - val_accuracy: 0.4167 Epoch 8/100 2/2 [==============================] - 0s 44ms/step - loss: 0.6864 - accuracy: 0.5208 - val_loss: 0.7025 - val_accuracy: 0.4167 Epoch 9/100 2/2 [==============================] - 0s 32ms/step - loss: 0.6869 - accuracy: 0.5208 - val_loss: 0.6997 - val_accuracy: 0.4833 Epoch 10/100 2/2 [==============================] - 0s 31ms/step - loss: 0.6870 - accuracy: 0.5958 - val_loss: 0.6973 - val_accuracy: 0.5833 7813/7813 [==============================] - 9s 1ms/step
Epoch 11/100 2/2 [==============================] - 0s 35ms/step - loss: 0.6863 - accuracy: 0.6625 - val_loss: 0.6948 - val_accuracy: 0.6500 Epoch 12/100 2/2 [==============================] - 0s 41ms/step - loss: 0.6849 - accuracy: 0.6875 - val_loss: 0.6935 - val_accuracy: 0.6333 Epoch 13/100 2/2 [==============================] - 0s 33ms/step - loss: 0.6834 - accuracy: 0.6917 - val_loss: 0.6922 - val_accuracy: 0.5667 Epoch 14/100 2/2 [==============================] - 0s 31ms/step - loss: 0.6814 - accuracy: 0.6500 - val_loss: 0.6919 - val_accuracy: 0.5167 Epoch 15/100 2/2 [==============================] - 0s 44ms/step - loss: 0.6791 - accuracy: 0.5750 - val_loss: 0.6939 - val_accuracy: 0.4833 Epoch 16/100 2/2 [==============================] - 0s 34ms/step - loss: 0.6770 - accuracy: 0.5583 - val_loss: 0.6960 - val_accuracy: 0.5000 Epoch 17/100 2/2 [==============================] - 0s 46ms/step - loss: 0.6758 - accuracy: 0.5750 - val_loss: 0.6986 - val_accuracy: 0.4667 Epoch 18/100 2/2 [==============================] - 0s 33ms/step - loss: 0.6742 - accuracy: 0.6000 - val_loss: 0.7000 - val_accuracy: 0.4667 Epoch 19/100 2/2 [==============================] - 0s 33ms/step - loss: 0.6734 - accuracy: 0.6083 - val_loss: 0.7001 - val_accuracy: 0.4833 Epoch 20/100 2/2 [==============================] - 0s 43ms/step - loss: 0.6723 - accuracy: 0.5875 - val_loss: 0.6992 - val_accuracy: 0.4500 7813/7813 [==============================] - 8s 1ms/step
Epoch 21/100 2/2 [==============================] - 0s 34ms/step - loss: 0.6705 - accuracy: 0.5875 - val_loss: 0.6978 - val_accuracy: 0.4667 Epoch 22/100 2/2 [==============================] - 0s 46ms/step - loss: 0.6686 - accuracy: 0.5917 - val_loss: 0.6951 - val_accuracy: 0.4667 Epoch 23/100 2/2 [==============================] - 0s 33ms/step - loss: 0.6670 - accuracy: 0.6042 - val_loss: 0.6916 - val_accuracy: 0.4667 Epoch 24/100 2/2 [==============================] - 0s 30ms/step - loss: 0.6650 - accuracy: 0.6167 - val_loss: 0.6897 - val_accuracy: 0.5000 Epoch 25/100 2/2 [==============================] - 0s 51ms/step - loss: 0.6631 - accuracy: 0.6208 - val_loss: 0.6864 - val_accuracy: 0.4833 Epoch 26/100 2/2 [==============================] - 0s 32ms/step - loss: 0.6612 - accuracy: 0.6250 - val_loss: 0.6841 - val_accuracy: 0.4667 Epoch 27/100 2/2 [==============================] - 0s 55ms/step - loss: 0.6591 - accuracy: 0.6208 - val_loss: 0.6820 - val_accuracy: 0.4667 Epoch 28/100 2/2 [==============================] - 0s 32ms/step - loss: 0.6572 - accuracy: 0.6208 - val_loss: 0.6792 - val_accuracy: 0.4667 Epoch 29/100 2/2 [==============================] - 0s 57ms/step - loss: 0.6554 - accuracy: 0.6208 - val_loss: 0.6774 - val_accuracy: 0.4667 Epoch 30/100 2/2 [==============================] - 0s 31ms/step - loss: 0.6534 - accuracy: 0.6083 - val_loss: 0.6766 - val_accuracy: 0.4667 7813/7813 [==============================] - 9s 1ms/step
Epoch 31/100 2/2 [==============================] - 0s 55ms/step - loss: 0.6511 - accuracy: 0.6083 - val_loss: 0.6754 - val_accuracy: 0.4667 Epoch 32/100 2/2 [==============================] - 0s 36ms/step - loss: 0.6487 - accuracy: 0.6083 - val_loss: 0.6735 - val_accuracy: 0.4667 Epoch 33/100 2/2 [==============================] - 0s 56ms/step - loss: 0.6461 - accuracy: 0.6208 - val_loss: 0.6719 - val_accuracy: 0.4833 Epoch 34/100 2/2 [==============================] - 0s 39ms/step - loss: 0.6434 - accuracy: 0.6292 - val_loss: 0.6706 - val_accuracy: 0.5000 Epoch 35/100 2/2 [==============================] - 0s 59ms/step - loss: 0.6406 - accuracy: 0.6375 - val_loss: 0.6690 - val_accuracy: 0.5000 Epoch 36/100 2/2 [==============================] - 0s 36ms/step - loss: 0.6380 - accuracy: 0.6458 - val_loss: 0.6670 - val_accuracy: 0.5167 Epoch 37/100 2/2 [==============================] - 0s 58ms/step - loss: 0.6350 - accuracy: 0.6458 - val_loss: 0.6655 - val_accuracy: 0.5333 Epoch 38/100 2/2 [==============================] - 0s 43ms/step - loss: 0.6323 - accuracy: 0.6458 - val_loss: 0.6654 - val_accuracy: 0.5167 Epoch 39/100 2/2 [==============================] - 0s 63ms/step - loss: 0.6288 - accuracy: 0.6500 - val_loss: 0.6631 - val_accuracy: 0.5333 Epoch 40/100 2/2 [==============================] - 0s 35ms/step - loss: 0.6257 - accuracy: 0.6542 - val_loss: 0.6586 - val_accuracy: 0.5500 7813/7813 [==============================] - 9s 1ms/step
Epoch 41/100 2/2 [==============================] - 0s 43ms/step - loss: 0.6222 - accuracy: 0.6542 - val_loss: 0.6556 - val_accuracy: 0.5500 Epoch 42/100 2/2 [==============================] - 0s 52ms/step - loss: 0.6191 - accuracy: 0.6542 - val_loss: 0.6532 - val_accuracy: 0.5500 Epoch 43/100 2/2 [==============================] - 0s 32ms/step - loss: 0.6150 - accuracy: 0.6583 - val_loss: 0.6495 - val_accuracy: 0.5833 Epoch 44/100 2/2 [==============================] - 0s 57ms/step - loss: 0.6114 - accuracy: 0.7125 - val_loss: 0.6431 - val_accuracy: 0.6333 Epoch 45/100 2/2 [==============================] - 0s 32ms/step - loss: 0.6071 - accuracy: 0.7542 - val_loss: 0.6390 - val_accuracy: 0.6500 Epoch 46/100 2/2 [==============================] - 0s 63ms/step - loss: 0.6028 - accuracy: 0.7875 - val_loss: 0.6353 - val_accuracy: 0.6667 Epoch 47/100 2/2 [==============================] - 0s 33ms/step - loss: 0.5985 - accuracy: 0.8083 - val_loss: 0.6316 - val_accuracy: 0.7000 Epoch 48/100 2/2 [==============================] - 0s 56ms/step - loss: 0.5940 - accuracy: 0.8042 - val_loss: 0.6280 - val_accuracy: 0.7000 Epoch 49/100 2/2 [==============================] - 0s 34ms/step - loss: 0.5893 - accuracy: 0.8125 - val_loss: 0.6227 - val_accuracy: 0.7500 Epoch 50/100 2/2 [==============================] - 0s 57ms/step - loss: 0.5842 - accuracy: 0.8167 - val_loss: 0.6190 - val_accuracy: 0.7333 7813/7813 [==============================] - 9s 1ms/step
Epoch 51/100 2/2 [==============================] - 0s 71ms/step - loss: 0.5791 - accuracy: 0.8000 - val_loss: 0.6151 - val_accuracy: 0.7000 Epoch 52/100 2/2 [==============================] - 0s 31ms/step - loss: 0.5736 - accuracy: 0.7917 - val_loss: 0.6105 - val_accuracy: 0.7000 Epoch 53/100 2/2 [==============================] - 0s 70ms/step - loss: 0.5681 - accuracy: 0.8083 - val_loss: 0.6044 - val_accuracy: 0.7333 Epoch 54/100 2/2 [==============================] - 0s 69ms/step - loss: 0.5621 - accuracy: 0.8250 - val_loss: 0.5985 - val_accuracy: 0.7500 Epoch 55/100 2/2 [==============================] - 0s 33ms/step - loss: 0.5560 - accuracy: 0.8375 - val_loss: 0.5924 - val_accuracy: 0.8000 Epoch 56/100 2/2 [==============================] - 0s 61ms/step - loss: 0.5496 - accuracy: 0.8458 - val_loss: 0.5849 - val_accuracy: 0.8333 Epoch 57/100 2/2 [==============================] - 0s 35ms/step - loss: 0.5427 - accuracy: 0.8625 - val_loss: 0.5759 - val_accuracy: 0.8500 Epoch 58/100 2/2 [==============================] - 0s 55ms/step - loss: 0.5362 - accuracy: 0.8958 - val_loss: 0.5696 - val_accuracy: 0.8667 Epoch 59/100 2/2 [==============================] - 0s 32ms/step - loss: 0.5283 - accuracy: 0.9125 - val_loss: 0.5585 - val_accuracy: 0.8833 Epoch 60/100 2/2 [==============================] - 0s 80ms/step - loss: 0.5205 - accuracy: 0.9333 - val_loss: 0.5511 - val_accuracy: 0.8833 7813/7813 [==============================] - 9s 1ms/step
Epoch 61/100 2/2 [==============================] - 0s 69ms/step - loss: 0.5134 - accuracy: 0.9333 - val_loss: 0.5439 - val_accuracy: 0.8833 Epoch 62/100 2/2 [==============================] - 0s 68ms/step - loss: 0.5039 - accuracy: 0.9375 - val_loss: 0.5319 - val_accuracy: 0.8833 Epoch 63/100 2/2 [==============================] - 0s 34ms/step - loss: 0.4954 - accuracy: 0.9458 - val_loss: 0.5227 - val_accuracy: 0.9000 Epoch 64/100 2/2 [==============================] - 0s 71ms/step - loss: 0.4860 - accuracy: 0.9542 - val_loss: 0.5115 - val_accuracy: 0.9000 Epoch 65/100 2/2 [==============================] - 0s 32ms/step - loss: 0.4765 - accuracy: 0.9583 - val_loss: 0.5009 - val_accuracy: 0.9167 Epoch 66/100 2/2 [==============================] - 0s 67ms/step - loss: 0.4673 - accuracy: 0.9667 - val_loss: 0.4882 - val_accuracy: 0.9333 Epoch 67/100 2/2 [==============================] - 0s 36ms/step - loss: 0.4568 - accuracy: 0.9708 - val_loss: 0.4783 - val_accuracy: 0.9333 Epoch 68/100 2/2 [==============================] - 0s 41ms/step - loss: 0.4465 - accuracy: 0.9667 - val_loss: 0.4685 - val_accuracy: 0.9333 Epoch 69/100 2/2 [==============================] - 0s 66ms/step - loss: 0.4390 - accuracy: 0.9542 - val_loss: 0.4639 - val_accuracy: 0.9167 Epoch 70/100 2/2 [==============================] - 0s 35ms/step - loss: 0.4262 - accuracy: 0.9542 - val_loss: 0.4455 - val_accuracy: 0.9333 7813/7813 [==============================] - 10s 1ms/step
Epoch 71/100 2/2 [==============================] - 0s 90ms/step - loss: 0.4137 - accuracy: 0.9708 - val_loss: 0.4289 - val_accuracy: 0.9333 Epoch 72/100 2/2 [==============================] - 0s 83ms/step - loss: 0.4024 - accuracy: 0.9875 - val_loss: 0.4154 - val_accuracy: 0.9667 Epoch 73/100 2/2 [==============================] - 0s 39ms/step - loss: 0.3909 - accuracy: 0.9958 - val_loss: 0.4052 - val_accuracy: 0.9500 Epoch 74/100 2/2 [==============================] - 0s 57ms/step - loss: 0.3789 - accuracy: 0.9958 - val_loss: 0.3934 - val_accuracy: 0.9667 Epoch 75/100 2/2 [==============================] - 0s 83ms/step - loss: 0.3672 - accuracy: 1.0000 - val_loss: 0.3830 - val_accuracy: 0.9500 Epoch 76/100 2/2 [==============================] - 0s 51ms/step - loss: 0.3557 - accuracy: 0.9958 - val_loss: 0.3700 - val_accuracy: 0.9500 Epoch 77/100 2/2 [==============================] - 0s 45ms/step - loss: 0.3430 - accuracy: 1.0000 - val_loss: 0.3538 - val_accuracy: 0.9833 Epoch 78/100 2/2 [==============================] - 0s 64ms/step - loss: 0.3337 - accuracy: 1.0000 - val_loss: 0.3387 - val_accuracy: 1.0000 Epoch 79/100 2/2 [==============================] - 0s 34ms/step - loss: 0.3194 - accuracy: 1.0000 - val_loss: 0.3341 - val_accuracy: 0.9667 Epoch 80/100 2/2 [==============================] - 0s 58ms/step - loss: 0.3082 - accuracy: 0.9917 - val_loss: 0.3268 - val_accuracy: 0.9500 7813/7813 [==============================] - 10s 1ms/step
Epoch 81/100 2/2 [==============================] - 0s 86ms/step - loss: 0.2962 - accuracy: 0.9917 - val_loss: 0.3114 - val_accuracy: 0.9833 Epoch 82/100 2/2 [==============================] - 0s 95ms/step - loss: 0.2845 - accuracy: 1.0000 - val_loss: 0.2922 - val_accuracy: 1.0000 Epoch 83/100 2/2 [==============================] - 0s 87ms/step - loss: 0.2720 - accuracy: 1.0000 - val_loss: 0.2815 - val_accuracy: 1.0000 Epoch 84/100 2/2 [==============================] - 0s 84ms/step - loss: 0.2615 - accuracy: 1.0000 - val_loss: 0.2699 - val_accuracy: 1.0000 Epoch 85/100 2/2 [==============================] - 0s 45ms/step - loss: 0.2503 - accuracy: 1.0000 - val_loss: 0.2655 - val_accuracy: 0.9833 Epoch 86/100 2/2 [==============================] - 0s 54ms/step - loss: 0.2385 - accuracy: 1.0000 - val_loss: 0.2523 - val_accuracy: 1.0000 Epoch 87/100 2/2 [==============================] - 0s 87ms/step - loss: 0.2281 - accuracy: 1.0000 - val_loss: 0.2376 - val_accuracy: 1.0000 Epoch 88/100 2/2 [==============================] - 0s 86ms/step - loss: 0.2179 - accuracy: 1.0000 - val_loss: 0.2250 - val_accuracy: 1.0000 Epoch 89/100 2/2 [==============================] - 0s 63ms/step - loss: 0.2074 - accuracy: 1.0000 - val_loss: 0.2175 - val_accuracy: 1.0000 Epoch 90/100 2/2 [==============================] - 0s 44ms/step - loss: 0.1976 - accuracy: 1.0000 - val_loss: 0.2094 - val_accuracy: 1.0000 7813/7813 [==============================] - 10s 1ms/step
Epoch 91/100 2/2 [==============================] - 0s 49ms/step - loss: 0.1886 - accuracy: 1.0000 - val_loss: 0.1984 - val_accuracy: 1.0000 Epoch 92/100 2/2 [==============================] - 0s 43ms/step - loss: 0.1793 - accuracy: 1.0000 - val_loss: 0.1893 - val_accuracy: 1.0000 Epoch 93/100 2/2 [==============================] - 0s 81ms/step - loss: 0.1707 - accuracy: 1.0000 - val_loss: 0.1791 - val_accuracy: 1.0000 Epoch 94/100 2/2 [==============================] - 0s 86ms/step - loss: 0.1625 - accuracy: 1.0000 - val_loss: 0.1703 - val_accuracy: 1.0000 Epoch 95/100 2/2 [==============================] - 0s 34ms/step - loss: 0.1549 - accuracy: 1.0000 - val_loss: 0.1614 - val_accuracy: 1.0000 Epoch 96/100 2/2 [==============================] - 0s 86ms/step - loss: 0.1476 - accuracy: 1.0000 - val_loss: 0.1538 - val_accuracy: 1.0000 Epoch 97/100 2/2 [==============================] - 0s 74ms/step - loss: 0.1405 - accuracy: 1.0000 - val_loss: 0.1461 - val_accuracy: 1.0000 Epoch 98/100 2/2 [==============================] - 0s 33ms/step - loss: 0.1335 - accuracy: 1.0000 - val_loss: 0.1400 - val_accuracy: 1.0000 Epoch 99/100 2/2 [==============================] - 0s 83ms/step - loss: 0.1273 - accuracy: 1.0000 - val_loss: 0.1337 - val_accuracy: 1.0000 Epoch 100/100 2/2 [==============================] - 0s 97ms/step - loss: 0.1216 - accuracy: 1.0000 - val_loss: 0.1257 - val_accuracy: 1.0000 7813/7813 [==============================] - 10s 1ms/step
Here we define some functions to plot the history data that the fit() function returns:
def plot_losses(hist):
    """Plot training and validation loss per epoch from a Keras History object."""
    for series in ('loss', 'val_loss'):
        plt.plot(hist.history[series])
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Val'])
    plt.show()
def plot_accuracies(hist):
    """Plot training and validation accuracy per epoch from a Keras History object."""
    for series in ('accuracy', 'val_accuracy'):
        plt.plot(hist.history[series])
    plt.title('Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Val'])
    plt.show()
We can see that our train and validation losses decrease with a similar trend, ending at similar values. This is a good indication that we are not overfitting to our training data.
# Loss curves for the training run above.
plot_losses(history)
We can also observe model performance via accuracy. Accuracy is simply defined as the number of datapoints classified correctly over the total number of datapoints.
# Accuracy curves for the same training run.
plot_accuracies(history)