In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(23) #kallisti

plt.rcParams['figure.figsize'] = (4, 4)
plt.rcParams['figure.dpi'] = 150
sns.set()

SVD Demo

In [2]:
rectangle = pd.read_csv("rectangle_data.csv")
rectangle.tail(5)
Out[2]:
    width  height  area  perimeter
95      8       5    40         26
96      8       7    56         30
97      1       4     4         10
98      1       6     6         14
99      2       6    12         16

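Singular value decomposition (SVD) is a numerical technique that automatically decomposes a matrix into a product of simpler matrices. Given an input matrix $X$, SVD returns $U$, $\Sigma$, and $V^T$ such that $X = U \Sigma V^T$. Grouping the first two factors, we can think of the result as two matrices, $U\Sigma$ and $V^T$.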

In [3]:
u, s, vt = np.linalg.svd(rectangle, full_matrices = False)

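The SVD routine returns $U$, the singular values, and $V^T$ as three separate variables. The singular values come back as a 1-D array s rather than as the diagonal matrix $\Sigma$, so to compute $U \Sigma$ we scale each column of u by the matching entry of s, which NumPy broadcasting lets us write simply as: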

In [4]:
usig = u * s
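As a quick sanity check on the broadcasting shortcut, the sketch below compares u * s with the explicit product u @ np.diag(s). It uses a random matrix in place of the rectangle data (an assumption, so that it runs without the CSV file).

```python
import numpy as np

# A random 100 x 4 matrix stands in for the rectangle data (assumption for illustration)
rng = np.random.default_rng(0)
X = rng.random((100, 4))

u, s, vt = np.linalg.svd(X, full_matrices=False)

# Broadcasting: s has shape (4,), so u * s multiplies column j of u by s[j]
usig_broadcast = u * s
# Explicit version: build the diagonal matrix Sigma and multiply
usig_explicit = u @ np.diag(s)

print(u.shape, s.shape, vt.shape)                   # (100, 4) (4,) (4, 4)
print(np.allclose(usig_broadcast, usig_explicit))   # True
```

The broadcast form avoids materializing the diagonal matrix, which matters little here but saves memory when $\Sigma$ is large.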

The two key pieces of the decomposition are $U\Sigma$ and $V^T$, which we can think of for now as analogous to our 'data' and 'transformation operation' from our manual decomposition earlier.

As we did before with our manual decomposition, we can recover our original rectangle data by multiplying the left matrix $U\Sigma$ by the right matrix $V^T$.

In [5]:
pd.DataFrame(usig @ vt).head(4)
Out[5]:
     0    1     2     3
0  8.0  6.0  48.0  28.0
1  2.0  4.0   8.0  12.0
2  1.0  3.0   3.0   8.0
3  9.0  3.0  27.0  24.0
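The table only shows four rows, so a full check is useful. The sketch below builds a hypothetical stand-in for rectangle_data.csv (random integer widths and heights, with area and perimeter derived from them) and confirms that $U\Sigma V^T$ reproduces every entry. The comparison uses np.allclose because the reconstruction is exact only up to floating-point round-off.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for rectangle_data.csv: random integer widths and heights
rng = np.random.default_rng(42)
width = rng.integers(1, 10, size=100)
height = rng.integers(1, 10, size=100)
rect = pd.DataFrame({
    "width": width,
    "height": height,
    "area": width * height,
    "perimeter": 2 * width + 2 * height,
})

u, s, vt = np.linalg.svd(rect, full_matrices=False)
reconstructed = (u * s) @ vt

# Equality holds only up to round-off, so compare with a tolerance
print(np.allclose(reconstructed, rect.to_numpy()))  # True
```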
In [6]:
np.set_printoptions(suppress=True)
usig
Out[6]:
array([[-56.30926787,   4.08369641,  -0.76796869,   0.        ],
       [-13.92587137,  -5.61592446,   1.59106852,  -0.        ],
       [ -7.3883695 ,  -5.11089273,   1.51352951,   0.        ],
       [-36.84443159,  -4.80005945,  -3.80095908,  -0.        ],
       [-79.47260546,  13.00269827,   0.18659785,  -0.        ],
       [ -7.42135662,  -5.11810904,  -1.31469604,   0.        ],
       [-13.95885849,  -5.62314077,  -1.23715703,   0.        ],
       [-37.98955728,  -1.31360807,  -0.26071277,   0.        ],
       [-15.6692269 ,  -9.65347804,  -4.03555325,   0.        ],
       [-25.44680915,  -7.81311695,  -3.92620778,  -0.        ],
       [-32.68750933,  -2.52515864,   0.38769508,  -0.        ],
       [-53.89570114,   2.32104364,  -2.20593631,  -0.        ],
       [-40.87803851,  -1.86471027,  -2.34708823,  -0.        ],
       [ -5.34289549,  -3.98065864,   0.77963103,   0.        ],
       [-28.2033419 ,  -8.33535389,   5.30031897,   0.        ],
       [-64.00838899,   7.0615079 ,   1.43570386,  -0.        ],
       [-43.29160524,  -0.1020575 ,  -0.90912062,   0.        ],
       [-13.92587137,  -5.61592446,   1.59106852,  -0.        ],
       [-16.78136548,  -6.15981034,   2.33291861,   0.        ],
       [-18.43439279,  -4.99433025,  -0.47940371,   0.        ],
       [-17.73119447, -10.78732028,  -4.71576755,  -0.        ],
       [-48.59365319,   1.10949307,  -1.55752846,  -0.        ],
       [-79.47260546,  13.00269827,   0.18659785,  -0.        ],
       [-29.48041607,  -4.87776777,  -2.47233693,   0.        ],
       [-33.16242383,  -4.83891361,  -3.13664801,   0.        ],
       [-44.08513177,   0.48789886,   0.51294377,   0.        ],
       [-31.8939828 ,  -3.115115  ,  -1.03436931,   0.        ],
       [-40.87803851,  -1.86471027,  -2.34708823,  -0.        ],
       [-22.57482149,  -7.2656229 ,  -3.25394509,   0.        ],
       [-22.90992708,  -4.36551973,   0.27834961,  -0.        ],
       [ -9.48332419,  -6.25195129,  -1.99491034,   0.        ],
       [-70.94697069,   9.44214673,  -0.61091354,  -0.        ],
       [-13.95885849,  -5.62314077,  -1.23715703,   0.        ],
       [ -5.34289549,  -3.98065864,   0.77963103,   0.        ],
       [-17.73119447, -10.78732028,  -4.71576755,  -0.        ],
       [-27.40195494,  -3.74031736,  -0.37800985,  -0.        ],
       [-53.89570114,   2.32104364,  -2.20593631,  -0.        ],
       [-62.37185523,   5.89241965,   2.83391341,  -0.        ],
       [-29.41444183,  -4.86333514,   3.18411417,  -0.        ],
       [-29.48041607,  -4.87776777,  -2.47233693,   0.        ],
       [-57.1027944 ,   4.67365277,   0.6540957 ,  -0.        ],
       [-14.75238503,  -5.03318441,   0.18490736,   0.        ],
       [-32.68750933,  -2.52515864,   0.38769508,  -0.        ],
       [-29.48041607,  -4.87776777,  -2.47233693,   0.        ],
       [-25.44680915,  -7.81311695,  -3.92620778,  -0.        ],
       [-19.63685958,  -6.70369623,   3.0747687 ,   0.        ],
       [-15.6692269 ,  -9.65347804,  -4.03555325,   0.        ],
       [ -5.34289549,  -3.98065864,   0.77963103,   0.        ],
       [-22.08341343,  -4.94825977,   1.68451077,   0.        ],
       [-19.63685958,  -6.70369623,   3.0747687 ,   0.        ],
       [-37.97306372,  -1.30999991,   1.15340001,  -0.        ],
       [-19.79316204, -11.92116253,  -5.39598185,  -0.        ],
       [-25.79840831,  -4.91662193,  -1.80802586,   0.        ],
       [-19.63685958,  -6.70369623,   3.0747687 ,   0.        ],
       [-56.27628075,   4.09091272,   2.06025686,   0.        ],
       [-19.79316204, -11.92116253,  -5.39598185,  -0.        ],
       [-25.34784779,  -7.79146801,   4.55846888,   0.        ],
       [-29.41444183,  -4.86333514,   3.18411417,  -0.        ],
       [-37.98955728,  -1.31360807,  -0.26071277,   0.        ],
       [-56.30926787,   4.08369641,  -0.76796869,  -0.        ],
       [-28.31879682,  -8.36061099,  -4.59847046,  -0.        ],
       [-36.74547023,  -4.77841051,   4.68371758,  -0.        ],
       [-16.78136548,  -6.15981034,   2.33291861,   0.        ],
       [ -9.43384352,  -6.24112682,   2.24742798,   0.        ],
       [-14.75238503,  -5.03318441,   0.18490736,   0.        ],
       [-28.2033419 ,  -8.33535389,   5.30031897,   0.        ],
       [-31.8939828 ,  -3.115115  ,  -1.03436931,   0.        ],
       [-16.83084616,  -6.17063481,  -1.90941971,  -0.        ],
       [-19.63685958,  -6.70369623,   3.0747687 ,   0.        ],
       [-16.78136548,  -6.15981034,   2.33291861,   0.        ],
       [ -7.3883695 ,  -5.11089273,   1.51352951,   0.        ],
       [ -7.3883695 ,  -5.11089273,   1.51352951,   0.        ],
       [-50.18070626,   2.28940579,   1.28660032,   0.        ],
       [-56.30926787,   4.08369641,  -0.76796869,  -0.        ],
       [-22.49235369,  -7.24758212,   3.81661879,   0.        ],
       [-31.86099568,  -3.10789868,   1.79385624,  -0.        ],
       [ -3.29742148,  -2.85042454,   0.04573256,   0.        ],
       [-79.47260546,  13.00269827,   0.18659785,  -0.        ],
       [-32.68750933,  -2.52515864,   0.38769508,  -0.        ],
       [-36.74547023,  -4.77841051,   4.68371758,  -0.        ],
       [-25.34784779,  -7.79146801,   4.55846888,   0.        ],
       [ -8.21488316,  -4.52815268,   0.10736835,   0.        ],
       [-33.16242383,  -4.83891361,  -3.13664801,   0.        ],
       [-17.61573956, -10.76206319,   5.18302188,   0.        ],
       [-11.54529176,  -7.38579354,  -2.67512465,  -0.        ],
       [-43.25861812,  -0.09484119,   1.91910494,  -0.        ],
       [-53.89570114,   2.32104364,  -2.20593631,  -0.        ],
       [ -7.42135662,  -5.11810904,  -1.31469604,   0.        ],
       [-70.94697069,   9.44214673,  -0.61091354,  -0.        ],
       [-48.59365319,   1.10949307,  -1.55752846,  -0.        ],
       [-79.47260546,  13.00269827,   0.18659785,  -0.        ],
       [-27.38546138,  -3.73670921,   1.03610292,  -0.        ],
       [-27.38546138,  -3.73670921,   1.03610292,  -0.        ],
       [-11.47931753,  -7.37136091,   2.98132646,   0.        ],
       [-44.08513177,   0.48789886,   0.51294377,   0.        ],
       [-48.59365319,   1.10949307,  -1.55752846,  -0.        ],
       [-64.02488255,   7.05789975,   0.02159108,  -0.        ],
       [ -9.43384352,  -6.24112682,   2.24742798,   0.        ],
       [-13.52479154,  -8.501595  ,   3.71522493,  -0.        ],
       [-19.63685958,  -6.70369623,   3.0747687 ,  -0.        ]])

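In fact, we can use only the first 3 columns of usig and the first 3 rows of vt and still get back exactly the original data (up to floating-point round-off). This is because the last column of usig is 0: the perimeter column is exactly 2·width + 2·height, a linear combination of two other columns, so the data matrix has rank 3 and its fourth singular value is (numerically) zero.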

In [7]:
pd.DataFrame(usig[:, 0:3] @ vt[0:3, ]).head(4)
Out[7]:
0 1 2 3
0 8.0 6.0 48.0 28.0
1 2.0 4.0 8.0 12.0
2 1.0 3.0 3.0 8.0
3 9.0 3.0 27.0 24.0
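The claim that the last column of usig is 0 can be traced back to the singular values themselves. This sketch again uses a hypothetical stand-in for the rectangle data; it checks the linear dependency directly and then looks at s and the numerical rank.

```python
import numpy as np

# Hypothetical stand-in for the rectangle data (columns: width, height, area, perimeter)
rng = np.random.default_rng(1)
width = rng.integers(1, 10, size=100)
height = rng.integers(1, 10, size=100)
X = np.column_stack([width, height, width * height, 2 * width + 2 * height]).astype(float)

# 2*width + 2*height - perimeter = 0 for every row, so the columns are linearly dependent
print(np.allclose(X @ np.array([2.0, 2.0, 0.0, -1.0]), 0))  # True

u, s, vt = np.linalg.svd(X, full_matrices=False)
print(s)                          # the last singular value is round-off sized, essentially 0
print(np.linalg.matrix_rank(X))   # 3
```

Since column 4 of usig is u[:, 3] * s[3], a zero singular value makes that whole column zero, which is why dropping it costs nothing.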

If we use only the first 2 columns of usig and the first 2 rows of vt, we end up with an imperfect reconstruction, but it's surprisingly close.

In [8]:
pd.DataFrame(usig[:, 0:2] @ vt[0:2, ]).tail(4)
Out[8]:
           0         1          2          3
96  8.015221  6.984689  55.999828  29.999819
97  2.584341  2.406224   3.982129   9.981131
98  3.619075  3.365328   5.970458  13.968808
99  4.167581  3.819511  11.975551  15.974185

Even the one-dimensional (rank 1) approximation is better than you might expect.

In [9]:
pd.DataFrame(usig[:, 0:1] @ vt[0:1, ]).tail(4)
Out[9]:
           0         1          2          3
96  9.375531  8.319533  51.861441  35.390129
97  1.381452  1.225854   7.641603   5.214612
98  1.980513  1.757441  10.955353   7.475908
99  2.875538  2.551656  15.906251  10.854389
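How good each approximation is can be quantified instead of eyeballed. A standard SVD fact is that the Frobenius-norm error of the rank-k truncation equals the square root of the sum of the squared singular values that were discarded. The sketch below checks this for k = 1 through 4 on a hypothetical stand-in for the rectangle data.

```python
import numpy as np

# Hypothetical stand-in for the rectangle data: width, height, area, perimeter
rng = np.random.default_rng(2)
width = rng.integers(1, 10, size=100)
height = rng.integers(1, 10, size=100)
X = np.column_stack([width, height, width * height, 2 * width + 2 * height]).astype(float)

u, s, vt = np.linalg.svd(X, full_matrices=False)
usig = u * s

results = []
for k in range(1, 5):
    approx = usig[:, :k] @ vt[:k, :]            # rank-k reconstruction
    err = np.linalg.norm(X - approx)            # Frobenius norm of the residual
    predicted = np.sqrt(np.sum(s[k:] ** 2))     # from the discarded singular values
    results.append((err, predicted))
    print(f"k={k}: error={err:.4f}, predicted={predicted:.4f}")
```

The error shrinks as k grows and is essentially 0 from k = 3 onward, matching the rank-3 structure of the data.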

Rank 1 Approximation of 2D Data, Data Centering

In [10]:
# Downloads from https://www.gapminder.org/data/
cm_path = 'child_mortality_0_5_year_olds_dying_per_1000_born.csv'
fe_path = 'children_per_woman_total_fertility.csv'
cm = pd.read_csv(cm_path).set_index('country')['2017'].to_frame() / 10  # deaths per 1000 births -> per 100 births (percent)
fe = pd.read_csv(fe_path).set_index('country')['2017'].to_frame()
child_data = cm.merge(fe, left_index=True, right_index=True).dropna()
child_data.columns = ['mortality', 'fertility']
child_data.head()
Out[10]:
                     mortality  fertility
country
Afghanistan              6.820       4.48
Albania                  1.330       1.71
Algeria                  2.390       2.71
Angola                   8.310       5.62
Antigua and Barbuda      0.816       2.04
In [11]:
def scatter14(data):
    # Scatter plot on fixed 0-14 axes so both variables share the same scale
    sns.scatterplot(x='mortality', y='fertility', data=data)
    plt.xlim([0, 14])
    plt.ylim([0, 14])
    plt.xticks(np.arange(0, 14, 2))
    plt.yticks(np.arange(0, 14, 2))

scatter14(child_data)
In [12]:
sns.scatterplot(x='mortality', y='fertility', data=child_data)
[Figure: scatter plot of fertility against mortality, one point per country]
In [13]:
u, s, vt = np.linalg.svd(child_data, full_matrices = False)
In [14]:
child_data_reconstructed = pd.DataFrame(u @ np.diag(s) @ vt, columns = ["mortality", "fertility"], index=child_data.index)

As we'd expect, the product of $U$, $\Sigma$, and $V^T$ recovers the original data perfectly.

In [15]:
child_data_reconstructed.head(5)
Out[15]:
                     mortality  fertility
country
Afghanistan              6.820       4.48
Albania                  1.330       1.71
Algeria                  2.390       2.71
Angola                   8.310       5.62
Antigua and Barbuda      0.816       2.04
In [16]:
sns.scatterplot(x='mortality', y='fertility', data=child_data)
[Figure: scatter plot of fertility against mortality, one point per country]

What happens if we throw away the last column of $U$, the smaller singular value in $\Sigma$, and the last row of $V^T$? Since this data has only two columns, what remains is the "rank 1 approximation" of the data.

Looking at the data, we see that it does a surprisingly good job.
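Written in terms of the columns $u_i$ of $U$, the singular values $\sigma_i$ on the diagonal of $\Sigma$, and the rows $v_i^T$ of $V^T$, the decomposition of the two-column child data is a sum of two rank-one terms, and the rank 1 approximation keeps only the first (largest) one:

$$
X = U \Sigma V^T = \sigma_1 u_1 v_1^T + \sigma_2 u_2 v_2^T, \qquad X_1 = \sigma_1 u_1 v_1^T, \qquad \sigma_1 \ge \sigma_2 \ge 0.
$$

More generally, keeping the first $k$ terms gives $X_k = \sum_{i=1}^{k} \sigma_i u_i v_i^T$. By the Eckart–Young theorem this is the best rank-$k$ approximation of $X$ in the Frobenius norm, with error $\|X - X_k\|_F = \sqrt{\sum_{i > k} \sigma_i^2}$.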

In [17]:
# Rather than manually invoking np.linalg.svd over and over, define a helper
# that computes the rank-k approximation of the data in a single call
def compute_rank_k_approximation(data, k):
    u, s, vt = np.linalg.svd(data, full_matrices = False)
    return pd.DataFrame(u[:, 0:k] @ np.diag(s[0:k]) @ vt[0:k, :], columns = data.columns)
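Because compute_rank_k_approximation builds a new DataFrame without passing index=, its result is labeled 0, 1, 2, ... instead of by country, as Out[18] shows. A possible variant, given here the hypothetical name compute_rank_k_approximation_indexed, keeps the original row labels. The small DataFrame in the usage lines is made up for illustration.

```python
import numpy as np
import pandas as pd

def compute_rank_k_approximation_indexed(data, k):
    """Hypothetical variant of compute_rank_k_approximation that keeps row labels."""
    u, s, vt = np.linalg.svd(data, full_matrices=False)
    approx = u[:, :k] @ np.diag(s[:k]) @ vt[:k, :]
    return pd.DataFrame(approx, columns=data.columns, index=data.index)

# Made-up toy data, not real Gapminder values
df = pd.DataFrame(
    {"mortality": [1.0, 2.0, 3.0, 4.0], "fertility": [2.0, 1.0, 4.0, 3.0]},
    index=["a", "b", "c", "d"],
)
approx = compute_rank_k_approximation_indexed(df, 1)
print(approx.index.tolist())                      # ['a', 'b', 'c', 'd']
print(np.linalg.matrix_rank(approx.to_numpy()))   # 1
```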
In [18]:
#child_data_rank_1_approximation = pd.DataFrame(u[:, :-1] @ np.diag(s[:-1]) @ vt[:-1, :], columns = ["mortality", "fertility"], index=child_data.index)

child_data_rank_1_approximation = compute_rank_k_approximation(child_data, 1)
child_data_rank_1_approximation.head(5)
Out[18]:
   mortality  fertility
0   6.694067   4.660869
1   1.697627   1.182004
2   2.880467   2.005579
3   8.232160   5.731795
4   1.506198   1.048719

By plotting the approximation in 2D, we can see what's going on: each row of the original data has been projected onto a one-dimensional subspace, namely the line through the origin spanned by the first row of $V^T$.

In [19]:
sns.scatterplot(x='mortality', y='fertility', data=child_data_rank_1_approximation)
[Figure: scatter plot of the rank 1 approximation, fertility against mortality]
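The projection interpretation can be checked numerically: if $v_1$ is the first row of $V^T$ (a unit vector), then $\sigma_1 u_1 v_1^T = X v_1 v_1^T$, which projects every row of $X$ onto the line spanned by $v_1$. The sketch uses made-up two-column data standing in for (mortality, fertility).

```python
import numpy as np

# Made-up 2-column data standing in for (mortality, fertility)
rng = np.random.default_rng(3)
x = rng.uniform(0, 10, size=50)
X = np.column_stack([x, 0.5 * x + 1.5 + rng.normal(0, 0.3, size=50)])

u, s, vt = np.linalg.svd(X, full_matrices=False)
v1 = vt[0]                                   # first right singular vector (unit length)

rank1 = s[0] * np.outer(u[:, 0], v1)         # sigma_1 * u_1 * v_1^T
projection = X @ np.outer(v1, v1)            # project each row onto span(v1)

print(np.allclose(rank1, projection))        # True
# Every projected point is a multiple of v1, so all points lie on one line through the origin
print(np.allclose(rank1[:, 1] / rank1[:, 0], v1[1] / v1[0]))  # True
```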

There's one significant issue with our projection, which we can see by plotting the original data and our reconstruction on the same axes. The line we project onto must pass through the origin, but the trend in our data has a non-zero y-intercept.

In [20]:
sns.scatterplot(x='mortality', y='fertility', data=child_data)
sns.scatterplot(x='mortality', y='fertility', data=child_data_rank_1_approximation)
[Figure: original data and rank 1 approximation overlaid, fertility against mortality]

While this y-intercept misalignment isn't terrible here, it can be really bad. For example, consider the 2D dataset below.

In [21]:
# Source: http://jse.amstat.org/datasets/fat.txt
body_data = pd.read_fwf("fat.dat.txt",
                        colspecs=[(9, 13), (17, 21), (23, 29), (35, 37),
                                  (39, 45), (48, 53), (57, 61), (64, 69),
                                  (73, 77), (80, 85), (88, 93), (96, 101),
                                  (105, 109), (113, 117), (121, 125), (129, 133),
                                  (137, 141), (145, 149)],
                        header=None,
                        names=["% brozek fat", "% siri fat", "density", "age",
                               "weight", "height", "adiposity", "fat free weight",
                               "neck", "chest", "abdomen", "hip", "thigh",
                               "knee", "ankle", "bicep", "forearm",
                               "wrist"])
# body_data = body_data.drop(41)  # drop the weird record
body_data.head()
Out[21]:
   % brozek fat  % siri fat  density  age  weight  height  adiposity  fat free weight  neck  chest  abdomen    hip  thigh  knee  ankle  bicep  forearm  wrist
0          12.6        12.3   1.0708   23  154.25   67.75       23.7            134.9  36.2   93.1     85.2   94.5   59.0  37.3   21.9   32.0     27.4   17.1
1           6.9         6.1   1.0853   22  173.25   72.25       23.4            161.3  38.5   93.6     83.0   98.7   58.7  37.3   23.4   30.5     28.9   18.2
2          24.6        25.3   1.0414   22  154.00   66.25       24.7            116.0  34.0   95.8     87.9   99.2   59.6  38.9   24.0   28.8     25.2   16.6
3          10.9        10.4   1.0751   26  184.75   72.25       24.9            164.7  37.4  101.8     86.4  101.2   60.1  37.3   22.8   32.4     29.4   18.2
4          27.8        28.7   1.0340   24  184.25   71.25       25.6            133.1  34.4   97.3    100.0  101.9   63.2  42.2   24.0   32.2     27.7   17.7
In [22]:
density_and_abdomen = body_data[["density", "abdomen"]]
density_and_abdomen.head(5)
Out[22]:
   density  abdomen
0   1.0708     85.2
1   1.0853     83.0
2   1.0414     87.9
3   1.0751     86.4
4   1.0340    100.0

Judging from the table alone, the rank 1 approximation looks at least vaguely sane: the abdomen column is reproduced almost exactly.

In [23]:
density_and_abdomen_rank_1_approximation = compute_rank_k_approximation(density_and_abdomen, 1)
density_and_abdomen_rank_1_approximation.head(5)
Out[23]:
    density    abdomen
0  0.957134  85.201277
1  0.932425  83.001717
2  0.987458  87.900606
3  0.970613  86.401174
4  1.123369  99.998996

But if we plot the data on 2D axes, we see that things are very wrong.

In [24]:
sns.scatterplot(x="density", y="abdomen", data=body_data)
[Figure: scatter plot of abdomen against density]
In [25]:
density_and_abdomen_rank_1_approximation = compute_rank_k_approximation(density_and_abdomen, 1)
sns.scatterplot(x="density", y="abdomen", data=body_data)
sns.scatterplot(x="density", y="abdomen", data=density_and_abdomen_rank_1_approximation);

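The subspace we project onto is a line through the origin, but this data sits far from the origin (density around 1, abdomen around 90), so the best such line points roughly toward the middle of the data cloud. Every projected point therefore lands on a line with positive slope, and we end up with a bizarre result where our rank 1 approximation believes that density increases with abdomen size, even though the data shows the opposite.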
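The sign flip can be reproduced with synthetic data. The sketch below generates a hypothetical stand-in for (density, abdomen) that is far from the origin and negatively correlated, then measures the slope of the line that the uncentered rank 1 approximation lives on. The generating numbers are assumptions chosen only to resemble the scale of the real columns.

```python
import numpy as np

# Synthetic stand-in for (density, abdomen): far from the origin, negatively correlated
rng = np.random.default_rng(4)
abdomen = rng.normal(92.0, 10.0, size=250)
density = 1.15 - 0.001 * abdomen + rng.normal(0, 0.005, size=250)
X = np.column_stack([density, abdomen])

print(np.corrcoef(density, abdomen)[0, 1] < 0)   # True: the data slopes downward

u, s, vt = np.linalg.svd(X, full_matrices=False)
v1 = vt[0]
# Slope (abdomen per unit density) of the line through the origin spanned by v1;
# the ratio is unaffected by the arbitrary sign of v1
print(v1[1] / v1[0] > 0)                         # True: the approximation slopes upward
```

Because every entry of X is positive, both components of v1 share a sign, so an uncentered rank 1 line through the origin cannot follow a downward trend here.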

To fix this issue, we should always start the SVD process by zero-centering our data. That is, for each column, we should subtract the mean of that column.

In [26]:
np.mean(density_and_abdomen, axis = 0)
Out[26]:
density     1.055574
abdomen    92.555952
dtype: float64
In [27]:
density_and_abdomen_centered = density_and_abdomen - np.mean(density_and_abdomen, axis = 0)
density_and_abdomen_centered.head(5)
Out[27]:
    density   abdomen
0   0.015226 -7.355952
1   0.029726 -9.555952
2  -0.014174 -4.655952
3   0.019526 -6.155952
4  -0.021574  7.444048
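Centering only shifts each column; it leaves the spread of the data unchanged. The sketch below uses the five rows shown in Out[22] and the pandas idiom df - df.mean(), which behaves like the np.mean(..., axis = 0) version above because the means are aligned to columns by name.

```python
import pandas as pd

# The first five rows of density_and_abdomen, as shown in Out[22]
df = pd.DataFrame({"density": [1.0708, 1.0853, 1.0414, 1.0751, 1.0340],
                   "abdomen": [85.2, 83.0, 87.9, 86.4, 100.0]})

# df.mean() is a Series indexed by column name, so each column is shifted by its own mean
centered = df - df.mean()

print(centered.mean().abs().max() < 1e-12)              # True: every column mean is ~0
print((centered.std() - df.std()).abs().max() < 1e-12)  # True: the spread is unchanged
```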

Now when we do the approximation, things work much better.

In [28]:
density_and_abdomen_centered_rank_1_approximation = compute_rank_k_approximation(density_and_abdomen_centered, 1)
sns.scatterplot(x="density", y="abdomen", data=density_and_abdomen_centered)
sns.scatterplot(x="density", y="abdomen", data=density_and_abdomen_centered_rank_1_approximation);
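The plot above stays in centered coordinates. To compare the approximation with the original measurements, one option is to add the column means back after the SVD step. The sketch below does this on the same kind of synthetic (density, abdomen) stand-in as before (an assumption, not the real body fat data) and checks that the fitted line now slopes downward.

```python
import numpy as np

# Synthetic stand-in for (density, abdomen), negatively correlated and far from the origin
rng = np.random.default_rng(4)
abdomen = rng.normal(92.0, 10.0, size=250)
density = 1.15 - 0.001 * abdomen + rng.normal(0, 0.005, size=250)
X = np.column_stack([density, abdomen])

means = X.mean(axis=0)
Xc = X - means                                   # zero-center each column

u, s, vt = np.linalg.svd(Xc, full_matrices=False)
rank1_centered = s[0] * np.outer(u[:, 0], vt[0])
rank1_original_units = rank1_centered + means    # shift back to the data's own location

v1 = vt[0]
print(v1[1] / v1[0] < 0)   # True: the approximation now slopes downward, like the data
```

Shifting back by the means turns the line through the origin into a line through the data's center, which is how the y-intercept problem described earlier is resolved.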