import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# Seed NumPy's global random number generator so random draws are repeatable.
np.random.seed(23) #kallisti
# Default every figure to a small 4x4 inch square rendered at 150 DPI.
plt.rcParams['figure.figsize'] = (4, 4)
plt.rcParams['figure.dpi'] = 150
# Apply seaborn's default theme to subsequent matplotlib plots.
sns.set()
# Load the rectangle dataset and preview its last five rows.
rectangle = pd.read_csv("rectangle_data.csv")
rectangle.tail(5)
width | height | area | perimeter | |
---|---|---|---|---|
95 | 8 | 5 | 40 | 26 |
96 | 8 | 7 | 56 | 30 |
97 | 1 | 4 | 4 | 10 |
98 | 1 | 6 | 6 | 14 |
99 | 2 | 6 | 12 | 16 |
Singular value decomposition (SVD) is a numerical technique to automatically decompose a matrix into two matrices. Given an input matrix $X$, SVD will return $U\Sigma$ and $V^T$ such that $X = U \Sigma V^T$.
u, s, vt = np.linalg.svd(rectangle, full_matrices = False)
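The SVD routine returns $U$ and $\Sigma$ as two separate variables. NumPy returns $\Sigma$ as a 1-D array of singular values rather than as a diagonal matrix, so broadcasting `u * s` scales each column of $U$ by its singular value, which is exactly $U \Sigma$. To compute $U \Sigma$ we simply write: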
usig = u * s
The two key pieces of the decomposition are $U\Sigma$ and $V^T$, which we can think of for now as analogous to our 'data' and 'transformation operation' from our manual decomposition earlier.
As we did before with our manual decomposition, we can recover our original rectangle data by multiplying the left matrix $U\Sigma$ by the right matrix $V^T$.
pd.DataFrame(usig @ vt).head(4)
0 | 1 | 2 | 3 | |
---|---|---|---|---|
0 | 8.0 | 6.0 | 48.0 | 28.0 |
1 | 2.0 | 4.0 | 8.0 | 12.0 |
2 | 1.0 | 3.0 | 3.0 | 8.0 |
3 | 9.0 | 3.0 | 27.0 | 24.0 |
np.set_printoptions(suppress=True)
usig
array([[-56.30926787, 4.08369641, -0.76796869, 0. ], [-13.92587137, -5.61592446, 1.59106852, -0. ], [ -7.3883695 , -5.11089273, 1.51352951, 0. ], [-36.84443159, -4.80005945, -3.80095908, -0. ], [-79.47260546, 13.00269827, 0.18659785, -0. ], [ -7.42135662, -5.11810904, -1.31469604, 0. ], [-13.95885849, -5.62314077, -1.23715703, 0. ], [-37.98955728, -1.31360807, -0.26071277, 0. ], [-15.6692269 , -9.65347804, -4.03555325, 0. ], [-25.44680915, -7.81311695, -3.92620778, -0. ], [-32.68750933, -2.52515864, 0.38769508, -0. ], [-53.89570114, 2.32104364, -2.20593631, -0. ], [-40.87803851, -1.86471027, -2.34708823, -0. ], [ -5.34289549, -3.98065864, 0.77963103, 0. ], [-28.2033419 , -8.33535389, 5.30031897, 0. ], [-64.00838899, 7.0615079 , 1.43570386, -0. ], [-43.29160524, -0.1020575 , -0.90912062, 0. ], [-13.92587137, -5.61592446, 1.59106852, -0. ], [-16.78136548, -6.15981034, 2.33291861, 0. ], [-18.43439279, -4.99433025, -0.47940371, 0. ], [-17.73119447, -10.78732028, -4.71576755, -0. ], [-48.59365319, 1.10949307, -1.55752846, -0. ], [-79.47260546, 13.00269827, 0.18659785, -0. ], [-29.48041607, -4.87776777, -2.47233693, 0. ], [-33.16242383, -4.83891361, -3.13664801, 0. ], [-44.08513177, 0.48789886, 0.51294377, 0. ], [-31.8939828 , -3.115115 , -1.03436931, 0. ], [-40.87803851, -1.86471027, -2.34708823, -0. ], [-22.57482149, -7.2656229 , -3.25394509, 0. ], [-22.90992708, -4.36551973, 0.27834961, -0. ], [ -9.48332419, -6.25195129, -1.99491034, 0. ], [-70.94697069, 9.44214673, -0.61091354, -0. ], [-13.95885849, -5.62314077, -1.23715703, 0. ], [ -5.34289549, -3.98065864, 0.77963103, 0. ], [-17.73119447, -10.78732028, -4.71576755, -0. ], [-27.40195494, -3.74031736, -0.37800985, -0. ], [-53.89570114, 2.32104364, -2.20593631, -0. ], [-62.37185523, 5.89241965, 2.83391341, -0. ], [-29.41444183, -4.86333514, 3.18411417, -0. ], [-29.48041607, -4.87776777, -2.47233693, 0. ], [-57.1027944 , 4.67365277, 0.6540957 , -0. ], [-14.75238503, -5.03318441, 0.18490736, 0. 
], [-32.68750933, -2.52515864, 0.38769508, -0. ], [-29.48041607, -4.87776777, -2.47233693, 0. ], [-25.44680915, -7.81311695, -3.92620778, -0. ], [-19.63685958, -6.70369623, 3.0747687 , 0. ], [-15.6692269 , -9.65347804, -4.03555325, 0. ], [ -5.34289549, -3.98065864, 0.77963103, 0. ], [-22.08341343, -4.94825977, 1.68451077, 0. ], [-19.63685958, -6.70369623, 3.0747687 , 0. ], [-37.97306372, -1.30999991, 1.15340001, -0. ], [-19.79316204, -11.92116253, -5.39598185, -0. ], [-25.79840831, -4.91662193, -1.80802586, 0. ], [-19.63685958, -6.70369623, 3.0747687 , 0. ], [-56.27628075, 4.09091272, 2.06025686, 0. ], [-19.79316204, -11.92116253, -5.39598185, -0. ], [-25.34784779, -7.79146801, 4.55846888, 0. ], [-29.41444183, -4.86333514, 3.18411417, -0. ], [-37.98955728, -1.31360807, -0.26071277, 0. ], [-56.30926787, 4.08369641, -0.76796869, -0. ], [-28.31879682, -8.36061099, -4.59847046, -0. ], [-36.74547023, -4.77841051, 4.68371758, -0. ], [-16.78136548, -6.15981034, 2.33291861, 0. ], [ -9.43384352, -6.24112682, 2.24742798, 0. ], [-14.75238503, -5.03318441, 0.18490736, 0. ], [-28.2033419 , -8.33535389, 5.30031897, 0. ], [-31.8939828 , -3.115115 , -1.03436931, 0. ], [-16.83084616, -6.17063481, -1.90941971, -0. ], [-19.63685958, -6.70369623, 3.0747687 , 0. ], [-16.78136548, -6.15981034, 2.33291861, 0. ], [ -7.3883695 , -5.11089273, 1.51352951, 0. ], [ -7.3883695 , -5.11089273, 1.51352951, 0. ], [-50.18070626, 2.28940579, 1.28660032, 0. ], [-56.30926787, 4.08369641, -0.76796869, -0. ], [-22.49235369, -7.24758212, 3.81661879, 0. ], [-31.86099568, -3.10789868, 1.79385624, -0. ], [ -3.29742148, -2.85042454, 0.04573256, 0. ], [-79.47260546, 13.00269827, 0.18659785, -0. ], [-32.68750933, -2.52515864, 0.38769508, -0. ], [-36.74547023, -4.77841051, 4.68371758, -0. ], [-25.34784779, -7.79146801, 4.55846888, 0. ], [ -8.21488316, -4.52815268, 0.10736835, 0. ], [-33.16242383, -4.83891361, -3.13664801, 0. ], [-17.61573956, -10.76206319, 5.18302188, 0. 
], [-11.54529176, -7.38579354, -2.67512465, -0. ], [-43.25861812, -0.09484119, 1.91910494, -0. ], [-53.89570114, 2.32104364, -2.20593631, -0. ], [ -7.42135662, -5.11810904, -1.31469604, 0. ], [-70.94697069, 9.44214673, -0.61091354, -0. ], [-48.59365319, 1.10949307, -1.55752846, -0. ], [-79.47260546, 13.00269827, 0.18659785, -0. ], [-27.38546138, -3.73670921, 1.03610292, -0. ], [-27.38546138, -3.73670921, 1.03610292, -0. ], [-11.47931753, -7.37136091, 2.98132646, 0. ], [-44.08513177, 0.48789886, 0.51294377, 0. ], [-48.59365319, 1.10949307, -1.55752846, -0. ], [-64.02488255, 7.05789975, 0.02159108, -0. ], [ -9.43384352, -6.24112682, 2.24742798, 0. ], [-13.52479154, -8.501595 , 3.71522493, -0. ], [-19.63685958, -6.70369623, 3.0747687 , -0. ]])
Naturally, we can instead use only the first 3 columns of usig and the first 3 rows of vt and still get back exactly the correct result. This is because the last column of usig is 0.
pd.DataFrame(usig[:, 0:3] @ vt[0:3, ]).head(4)
0 | 1 | 2 | 3 | |
---|---|---|---|---|
0 | 8.0 | 6.0 | 48.0 | 28.0 |
1 | 2.0 | 4.0 | 8.0 | 12.0 |
2 | 1.0 | 3.0 | 3.0 | 8.0 |
3 | 9.0 | 3.0 | 27.0 | 24.0 |
If we use only the first 2 columns of usig and the first 2 rows of vt, we end up with an imperfect reconstruction, but it's surprisingly not bad.
pd.DataFrame(usig[:, 0:2] @ vt[0:2, ]).tail(4)
0 | 1 | 2 | 3 | |
---|---|---|---|---|
96 | 8.015221 | 6.984689 | 55.999828 | 29.999819 |
97 | 2.584341 | 2.406224 | 3.982129 | 9.981131 |
98 | 3.619075 | 3.365328 | 5.970458 | 13.968808 |
99 | 4.167581 | 3.819511 | 11.975551 | 15.974185 |
Even the one dimensional approximation is better than you might expect.
pd.DataFrame(usig[:, 0:1] @ vt[0:1, ]).tail(4)
0 | 1 | 2 | 3 | |
---|---|---|---|---|
96 | 9.375531 | 8.319533 | 51.861441 | 35.390129 |
97 | 1.381452 | 1.225854 | 7.641603 | 5.214612 |
98 | 1.980513 | 1.757441 | 10.955353 | 7.475908 |
99 | 2.875538 | 2.551656 | 15.906251 | 10.854389 |
# Downloads from https://www.gapminder.org/data/
cm_path = 'child_mortality_0_5_year_olds_dying_per_1000_born.csv'
fe_path = 'children_per_woman_total_fertility.csv'
# Keep only the 2017 column of each table, indexed by country. Dividing by 10
# rescales mortality from per-1000 births (going by the file name) to per-100.
cm = pd.read_csv(cm_path).set_index('country')['2017'].to_frame()/10
fe = pd.read_csv(fe_path).set_index('country')['2017'].to_frame()
# Join on the country index (merge defaults to an inner join) and drop any
# country that is missing either value.
child_data = cm.merge(fe, left_index=True, right_index=True).dropna()
# Both inputs contributed a column named '2017'; give them descriptive names.
child_data.columns = ['mortality', 'fertility']
child_data.head()
mortality | fertility | |
---|---|---|
country | ||
Afghanistan | 6.820 | 4.48 |
Albania | 1.330 | 1.71 |
Algeria | 2.390 | 2.71 |
Angola | 8.310 | 5.62 |
Antigua and Barbuda | 0.816 | 2.04 |
def scatter14(data):
    """Scatter-plot fertility (y) against mortality (x) on fixed 0-14 axes.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing 'mortality' and 'fertility' columns.
    """
    # x and y must be passed by keyword: seaborn warns about positional
    # variables before 0.12 and no longer accepts them from 0.12 onward.
    sns.scatterplot(x='mortality', y='fertility', data=data)
    # Use identical limits and tick spacing on both axes so distances in the
    # plot are directly comparable between the two variables.
    plt.xlim([0, 14])
    plt.ylim([0, 14])
    plt.xticks(np.arange(0, 14, 2))
    plt.yticks(np.arange(0, 14, 2))
scatter14(child_data)
/opt/conda/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. warnings.warn(
# Pass x/y by keyword; positional variables are rejected from seaborn 0.12 on.
sns.scatterplot(x='mortality', y='fertility', data=child_data)
/opt/conda/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. warnings.warn(
<AxesSubplot:xlabel='mortality', ylabel='fertility'>
# Reduced SVD of the child data; `s` holds the singular values as a 1-D array.
u, s, vt = np.linalg.svd(child_data, full_matrices = False)
# Multiply the three factors back together and reattach the original labels.
child_data_reconstructed = pd.DataFrame(u @ np.diag(s) @ vt, columns = ["mortality", "fertility"], index=child_data.index)
As we'd expect, the product of $U$, $\Sigma$, and $V^T$ recovers the original data perfectly.
child_data_reconstructed.head(5)
mortality | fertility | |
---|---|---|
country | ||
Afghanistan | 6.820 | 4.48 |
Albania | 1.330 | 1.71 |
Algeria | 2.390 | 2.71 |
Angola | 8.310 | 5.62 |
Antigua and Barbuda | 0.816 | 2.04 |
# Pass x/y by keyword; positional variables are rejected from seaborn 0.12 on.
sns.scatterplot(x='mortality', y='fertility', data=child_data)
/opt/conda/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. warnings.warn(
<AxesSubplot:xlabel='mortality', ylabel='fertility'>
What happens if we throw away a column of $U$, a singular value from $\Sigma$, and a row from $V^T$? In this case we end up with the "rank 1 approximation" of the data.
Looking at the data, we see that it does a surprisingly good job.
#Rather than manually invoking linalg.svd over and over, let's just
#define a function that does the rank approximation in one function call
def compute_rank_k_approximation(data, k):
    """Return the rank-k approximation of `data` computed from its SVD.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric table to approximate.
    k : int
        Number of leading singular values (and matching singular vectors)
        to keep.

    Returns
    -------
    pandas.DataFrame
        Approximation with the same shape, column labels and row index as
        `data`.
    """
    # Work on a plain ndarray; the labels are reattached when building the
    # result below.
    matrix = np.asarray(data)
    u, s, vt = np.linalg.svd(matrix, full_matrices=False)
    # Scaling the first k columns of U by their singular values is the same
    # as U_k @ diag(s_k), without materialising the diagonal matrix.
    approximation = (u[:, :k] * s[:k]) @ vt[:k, :]
    # Keep the row index as well as the columns so that each approximated row
    # stays labelled like the row of `data` it came from.
    return pd.DataFrame(approximation, columns=data.columns, index=data.index)
#child_data_rank_1_approximation = pd.DataFrame(u[:, :-1] @ np.diag(s[:-1]) @ vt[:-1, :], columns = ["mortality", "fertility"], index=child_data.index)
child_data_rank_1_approximation = compute_rank_k_approximation(child_data, 1)
child_data_rank_1_approximation.head(5)
mortality | fertility | |
---|---|---|
0 | 6.694067 | 4.660869 |
1 | 1.697627 | 1.182004 |
2 | 2.880467 | 2.005579 |
3 | 8.232160 | 5.731795 |
4 | 1.506198 | 1.048719 |
By plotting the data in a 2D space, we can see what's going on. We're simply getting the original data projected on to some 1 dimensional subspace.
# Pass x/y by keyword; positional variables are rejected from seaborn 0.12 on.
sns.scatterplot(x='mortality', y='fertility', data=child_data_rank_1_approximation)
/opt/conda/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. warnings.warn(
<AxesSubplot:xlabel='mortality', ylabel='fertility'>
There's one significant issue with our projection, which we can see by plotting both the original data and our reconstruction on the same axes. The issue is that the projection goes through the origin but our data has a non-zero y-intercept.
# Overlay the original data and its rank 1 approximation on the same axes.
# Pass x/y by keyword; positional variables are rejected from seaborn 0.12 on.
sns.scatterplot(x='mortality', y='fertility', data=child_data)
sns.scatterplot(x='mortality', y='fertility', data=child_data_rank_1_approximation)
/opt/conda/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. warnings.warn( /opt/conda/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. warnings.warn(
<AxesSubplot:xlabel='mortality', ylabel='fertility'>
While this y-intercept misalignment isn't terrible here, it can be really bad. For example, consider the 2D dataset below.
#http://jse.amstat.org/datasets/fat.txt
# Parse the fixed-width text file. Each (start, end) pair in `colspecs` is the
# half-open character span of one field, and `names` labels those fields in
# the same order. header=None makes pandas treat the first line as data.
body_data = pd.read_fwf("fat.dat.txt", colspecs = [(9, 13), (17, 21), (23, 29), (35, 37),
(39, 45), (48, 53), (57, 61), (64, 69),
(73, 77), (80, 85), (88, 93), (96, 101),
(105, 109), (113, 117), (121, 125), (129, 133),
(137, 141), (145, 149)],
header=None, names = ["% brozek fat", "% siri fat", "density", "age",
"weight", "height", "adiposity", "fat free weight",
"neck", "chest", "abdomen", "hip", "thigh",
"knee", "ankle", "bicep", "forearm",
"wrist"])
#body_data = body_data.drop(41) #drop the weird record
body_data.head()
% brozek fat | % siri fat | density | age | weight | height | adiposity | fat free weight | neck | chest | abdomen | hip | thigh | knee | ankle | bicep | forearm | wrist | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 12.6 | 12.3 | 1.0708 | 23 | 154.25 | 67.75 | 23.7 | 134.9 | 36.2 | 93.1 | 85.2 | 94.5 | 59.0 | 37.3 | 21.9 | 32.0 | 27.4 | 17.1 |
1 | 6.9 | 6.1 | 1.0853 | 22 | 173.25 | 72.25 | 23.4 | 161.3 | 38.5 | 93.6 | 83.0 | 98.7 | 58.7 | 37.3 | 23.4 | 30.5 | 28.9 | 18.2 |
2 | 24.6 | 25.3 | 1.0414 | 22 | 154.00 | 66.25 | 24.7 | 116.0 | 34.0 | 95.8 | 87.9 | 99.2 | 59.6 | 38.9 | 24.0 | 28.8 | 25.2 | 16.6 |
3 | 10.9 | 10.4 | 1.0751 | 26 | 184.75 | 72.25 | 24.9 | 164.7 | 37.4 | 101.8 | 86.4 | 101.2 | 60.1 | 37.3 | 22.8 | 32.4 | 29.4 | 18.2 |
4 | 27.8 | 28.7 | 1.0340 | 24 | 184.25 | 71.25 | 25.6 | 133.1 | 34.4 | 97.3 | 100.0 | 101.9 | 63.2 | 42.2 | 24.0 | 32.2 | 27.7 | 17.7 |
density_and_abdomen = body_data[["density", "abdomen"]]
density_and_abdomen.head(5)
density | abdomen | |
---|---|---|
0 | 1.0708 | 85.2 |
1 | 1.0853 | 83.0 |
2 | 1.0414 | 87.9 |
3 | 1.0751 | 86.4 |
4 | 1.0340 | 100.0 |
If we look at the data, the rank 1 approximation looks at least vaguely sane from the table.
density_and_abdomen_rank_1_approximation = compute_rank_k_approximation(density_and_abdomen, 1)
density_and_abdomen_rank_1_approximation.head(5)
density | abdomen | |
---|---|---|
0 | 0.957134 | 85.201277 |
1 | 0.932425 | 83.001717 |
2 | 0.987458 | 87.900606 |
3 | 0.970613 | 86.401174 |
4 | 1.123369 | 99.998996 |
But if we plot on 2D axes, we'll see that things are very wrong.
sns.scatterplot(x="density", y="abdomen", data=body_data)
<AxesSubplot:xlabel='density', ylabel='abdomen'>
density_and_abdomen_rank_1_approximation = compute_rank_k_approximation(density_and_abdomen, 1)
sns.scatterplot(x="density", y="abdomen", data=body_data)
sns.scatterplot(x="density", y="abdomen", data=density_and_abdomen_rank_1_approximation);
Since the subspace that we're projecting on to is off and to the right, we end up with a bizarre result where our rank 1 approximation believes that density increases with abdomen size, even though the data shows the opposite.
To fix this issue, we should always start the SVD process by zero-centering our data. That is, for each column, we should subtract the mean of that column.
np.mean(density_and_abdomen, axis = 0)
density 1.055574 abdomen 92.555952 dtype: float64
density_and_abdomen_centered = density_and_abdomen - np.mean(density_and_abdomen, axis = 0)
density_and_abdomen_centered.head(5)
density | abdomen | |
---|---|---|
0 | 0.015226 | -7.355952 |
1 | 0.029726 | -9.555952 |
2 | -0.014174 | -4.655952 |
3 | 0.019526 | -6.155952 |
4 | -0.021574 | 7.444048 |
Now when we do the approximation, things work much better.
density_and_abdomen_centered_rank_1_approximation = compute_rank_k_approximation(density_and_abdomen_centered, 1)
sns.scatterplot(x="density", y="abdomen", data=density_and_abdomen_centered)
sns.scatterplot(x="density", y="abdomen", data=density_and_abdomen_centered_rank_1_approximation);