Open In Colab

4. Chapter 05 - Multivariate Linear Models

!pip install torch torchvision pyro-ppl proplot black blackcellmagic
Requirement already satisfied: torch in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (1.7.1)
Requirement already satisfied: torchvision in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (0.8.2)
Requirement already satisfied: pyro-ppl in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (1.3.1)
Requirement already satisfied: proplot in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (0.6.4)
Requirement already satisfied: black in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (20.8b1)
Requirement already satisfied: blackcellmagic in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (0.0.2)
Requirement already satisfied: typing-extensions>=3.7.4 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (3.7.4.3)
Requirement already satisfied: toml>=0.10.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (0.10.2)
Requirement already satisfied: pathspec<1,>=0.6 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (0.8.1)
Requirement already satisfied: typed-ast>=1.4.0 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (1.4.2)
Requirement already satisfied: regex>=2020.1.8 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (2020.11.13)
Requirement already satisfied: appdirs in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (1.4.4)
Requirement already satisfied: mypy-extensions>=0.4.3 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (0.4.3)
Requirement already satisfied: click>=7.1.2 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (7.1.2)
Requirement already satisfied: ipython in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from blackcellmagic) (7.19.0)
Requirement already satisfied: matplotlib in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from proplot) (3.3.3)
Requirement already satisfied: pyro-api>=0.1.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (0.1.2)
Requirement already satisfied: opt-einsum>=2.3.2 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (3.3.0)
Requirement already satisfied: tqdm>=4.36 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (4.56.0)
Requirement already satisfied: numpy>=1.7 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (1.19.5)
Requirement already satisfied: pillow>=4.1.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from torchvision) (8.1.0)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (3.0.10)
Requirement already satisfied: jedi>=0.10 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (0.18.0)
Requirement already satisfied: decorator in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (4.4.2)
Requirement already satisfied: pexpect>4.3 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (4.8.0)
Requirement already satisfied: traitlets>=4.2 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (5.0.5)
Requirement already satisfied: pickleshare in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (0.7.5)
Requirement already satisfied: setuptools>=18.5 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (47.1.0)
Requirement already satisfied: backcall in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (0.2.0)
Requirement already satisfied: pygments in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (2.7.4)
Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from jedi>=0.10->ipython->blackcellmagic) (0.8.1)
Requirement already satisfied: ptyprocess>=0.5 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pexpect>4.3->ipython->blackcellmagic) (0.7.0)
Requirement already satisfied: wcwidth in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->blackcellmagic) (0.2.5)
Requirement already satisfied: ipython-genutils in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from traitlets>=4.2->ipython->blackcellmagic) (0.2.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (2.4.7)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (1.3.1)
Requirement already satisfied: cycler>=0.10 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (0.10.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (2.8.1)
Requirement already satisfied: six in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from cycler>=0.10->matplotlib->proplot) (1.15.0)
%load_ext blackcellmagic
import warnings

import pandas as pd
import proplot as plot
import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.ops.stats as stats
import pyro.optim
import seaborn as sns
import torch
import torch.distributions.constraints as constraints
import torch.tensor as tensor
from pyro.contrib.autoguide import AutoLaplaceApproximation

warnings.filterwarnings("ignore")
%pylab inline
pyro.set_rng_seed(42)

plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["font.weight"] = "bold"

def summary(samples, prob=0.95):
    site_stats = {}
    for k, v in samples.items():
        site_stats[k] = {
            "mean": torch.mean(v, 0).data.numpy(),
            "std": torch.std(v, 0).data.numpy(),
            "{:.1f}%".format(100 * (1 - prob)): v.kthvalue(
                int(len(v) * (1 - prob)), dim=0
            )[0].data.numpy(),
            "{:.1f}%".format(100 * prob): v.kthvalue(int(len(v) * prob), dim=0)[
                0
            ].data.numpy(),
        }
    return pd.DataFrame(site_stats)
Populating the interactive namespace from numpy and matplotlib
waffle_divorce = pd.read_csv("https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv", sep=";")
d = waffle_divorce
# standardize predictor

d['MedianAgeMarriage_s'] = (d.MedianAgeMarriage - d.MedianAgeMarriage.mean())/d.MedianAgeMarriage.std()

def model51(MedianAgeMarriage_s):
  
  a = pyro.sample("a", dist.Normal(tensor(10.), tensor(10.)))
  bA = pyro.sample("bA", dist.Normal(tensor(0.), tensor(1.)))
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(50.)))
  mu = a + bA * MedianAgeMarriage_s

  divorce = pyro.sample("divorce", dist.Normal(mu, sigma))
  return divorce 

def link(laplace_guide, data, num_samples):
  pred = pyro.infer.Predictive(laplace_guide, num_samples=num_samples)
  samples = pred.get_samples()
  mu = samples["bA"].detach().reshape((num_samples,1))*data.reshape(1, data.shape[0]) + samples["a"].detach().reshape((num_samples,1))
  return mu

divorce = tensor(d.Divorce, dtype=torch.float)
MedianAgeMarriage_s = tensor(d['MedianAgeMarriage_s'], dtype=torch.float)

conditioned51 = pyro.condition(model51, data={"divorce": divorce})
guide51 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned51)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned51,
    guide=guide51,
    optim=pyro.optim.Adam({"lr": 0.1}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(MedianAgeMarriage_s) for t in range(num_steps)]
plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss")
plt.tight_layout()
laplace_guide51 = guide51.laplace_approximation(MedianAgeMarriage_s)
pred51 = pyro.infer.Predictive(laplace_guide51, num_samples=1000)  
../_images/05_Chapter05_4_0.png

4.1. Code 5.2

!pip install torch torchvision pyro-ppl proplot black blackcellmagic
Requirement already satisfied: torch in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (1.7.1)
Requirement already satisfied: torchvision in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (0.8.2)
Requirement already satisfied: pyro-ppl in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (1.3.1)
Requirement already satisfied: proplot in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (0.6.4)
Requirement already satisfied: black in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (20.8b1)
Requirement already satisfied: blackcellmagic in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (0.0.2)
Requirement already satisfied: typing-extensions>=3.7.4 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (3.7.4.3)
Requirement already satisfied: pathspec<1,>=0.6 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (0.8.1)
Requirement already satisfied: regex>=2020.1.8 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (2020.11.13)
Requirement already satisfied: appdirs in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (1.4.4)
Requirement already satisfied: click>=7.1.2 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (7.1.2)
Requirement already satisfied: typed-ast>=1.4.0 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (1.4.2)
Requirement already satisfied: toml>=0.10.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (0.10.2)
Requirement already satisfied: mypy-extensions>=0.4.3 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from black) (0.4.3)
Requirement already satisfied: ipython in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from blackcellmagic) (7.19.0)
Requirement already satisfied: matplotlib in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from proplot) (3.3.3)
Requirement already satisfied: pyro-api>=0.1.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (0.1.2)
Requirement already satisfied: opt-einsum>=2.3.2 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (3.3.0)
Requirement already satisfied: tqdm>=4.36 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (4.56.0)
Requirement already satisfied: numpy>=1.7 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pyro-ppl) (1.19.5)
Requirement already satisfied: pillow>=4.1.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from torchvision) (8.1.0)
Requirement already satisfied: setuptools>=18.5 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (47.1.0)
Requirement already satisfied: pickleshare in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (0.7.5)
Requirement already satisfied: pygments in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (2.7.4)
Requirement already satisfied: pexpect>4.3 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (4.8.0)
Requirement already satisfied: decorator in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (4.4.2)
Requirement already satisfied: traitlets>=4.2 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (5.0.5)
Requirement already satisfied: jedi>=0.10 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (0.18.0)
Requirement already satisfied: backcall in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (0.2.0)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from ipython->blackcellmagic) (3.0.10)
Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from jedi>=0.10->ipython->blackcellmagic) (0.8.1)
Requirement already satisfied: ptyprocess>=0.5 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from pexpect>4.3->ipython->blackcellmagic) (0.7.0)
Requirement already satisfied: wcwidth in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->blackcellmagic) (0.2.5)
Requirement already satisfied: ipython-genutils in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from traitlets>=4.2->ipython->blackcellmagic) (0.2.0)
Requirement already satisfied: cycler>=0.10 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (1.3.1)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (2.4.7)
Requirement already satisfied: python-dateutil>=2.1 in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from matplotlib->proplot) (2.8.1)
Requirement already satisfied: six in /opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages (from cycler>=0.10->matplotlib->proplot) (1.15.0)
mam_seq = torch.linspace(-3., 3.5, steps=30)
mu = link(laplace_guide51, mam_seq, 1000)
mu_PI = stats.pi(mu, dim=0, prob=0.89)

fig, ax = plt.subplots()
ax.scatter(waffle_divorce['MedianAgeMarriage_s'], waffle_divorce['Divorce'])

x = waffle_divorce['MedianAgeMarriage_s'].sort_values()
y = tensor(waffle_divorce['MedianAgeMarriage_s'].sort_values()) * pred51.get_samples()["bA"].detach().mean() + pred51.get_samples()["a"].detach().mean()
ax.scatter(waffle_divorce['MedianAgeMarriage_s'], waffle_divorce['Divorce'])
ax.plot(x,y)
ax.fill_between(mam_seq, mu_PI[0], mu_PI[1], color='gray')
ax.set_xlabel("MedianAgeMarriage.s")
ax.set_ylabel('Divorce')
fig.tight_layout()
../_images/05_Chapter05_7_0.png

4.2. Code 5.3

d['Marriage_s'] = (d.Marriage - d.Marriage.mean())/d.Marriage.std()

def model52(Marriage_s):
  
  a = pyro.sample("a", dist.Normal(tensor(10.), tensor(10.)))
  bR = pyro.sample("bR", dist.Normal(tensor(0.), tensor(1.)))
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))
  mu = a + bR * Marriage_s

  divorce = pyro.sample("divorce", dist.Normal(mu, sigma))
  return divorce 

def link(laplace_guide, data, num_samples):
  pred = pyro.infer.Predictive(laplace_guide, num_samples=num_samples)
  samples = pred.get_samples()
  mu = samples["bR"].detach().reshape((num_samples,1))*data.reshape(1, data.shape[0]) + samples["a"].detach().reshape((num_samples,1))
  return mu

divorce = tensor(d.Divorce, dtype=torch.float)
Marriage_s = tensor(d['Marriage_s'], dtype=torch.float)

conditioned52 = pyro.condition(model52, data={"divorce": divorce})
guide52 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned52)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned52,
    guide=guide52,
    optim=pyro.optim.Adam({"lr": 0.1}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(Marriage_s) for t in range(num_steps)]
plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss")
plt.tight_layout()
laplace_guide52 = guide52.laplace_approximation(Marriage_s)
pred52 = pyro.infer.Predictive(laplace_guide52, num_samples=1000)  


mam_seq = torch.linspace(-3., 3.5, steps=30)
mu = link(laplace_guide52, mam_seq, 1000)
mu_PI = stats.pi(mu, dim=0, prob=0.89)

fig, ax = plt.subplots()
ax.scatter(waffle_divorce['Marriage_s'], waffle_divorce['Divorce'])

x = waffle_divorce['Marriage_s'].sort_values()
y = tensor(waffle_divorce['Marriage_s'].sort_values()) * pred52.get_samples()["bR"].detach().mean() + pred52.get_samples()["a"].detach().mean()
ax.scatter(waffle_divorce['Marriage_s'], waffle_divorce['Divorce'])
ax.plot(x,y)
ax.fill_between(mam_seq, mu_PI[0], mu_PI[1], color='gray')
ax.set_xlabel("Marriage.s")
ax.set_ylabel('Divorce')
fig.tight_layout()
../_images/05_Chapter05_9_0.png ../_images/05_Chapter05_9_1.png

4.3. Code 5.4

def model53(Marriage_s, MedianAgeMarriage_s):
  
  a = pyro.sample("a", dist.Normal(tensor(10.), tensor(10.)))
  bA = pyro.sample("bA", dist.Normal(tensor(0.), tensor(1.)))
  bR = pyro.sample("bR", dist.Normal(tensor(0.), tensor(1.)))
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))
  mu = a + bR * Marriage_s + bA * MedianAgeMarriage_s

  divorce = pyro.sample("divorce", dist.Normal(mu, sigma))
  return divorce 

divorce = tensor(d.Divorce, dtype=torch.float)
Marriage_s = tensor(d['Marriage_s'], dtype=torch.float)
MedianAgeMarriage_s = tensor(d['MedianAgeMarriage_s'], dtype=torch.float)

conditioned53 = pyro.condition(model53, data={"divorce": divorce})
guide53 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned53)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned53,
    guide=guide53,
    optim=pyro.optim.Adam({"lr": 0.1}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(Marriage_s, MedianAgeMarriage_s) for t in range(num_steps)]
laplace_guide53 = guide53.laplace_approximation(Marriage_s, MedianAgeMarriage_s)
pred53 = pyro.infer.Predictive(laplace_guide53, num_samples=1000)  
samples53 = pred53.get_samples()
post53 = summary(samples53) 
pred53
Predictive(
  (model): AutoMultivariateNormal()
)

4.4. Code 5.5

precis = summary(pred53.get_samples())
precis = precis.drop(columns = ["_AutoMultivariateNormal_latent"]).T
fig, ax = plt.subplots()
sns.pointplot(precis["mean"], precis.index, join=False, ax=ax)
for i, param in enumerate(precis.index):
  sns.lineplot(np.hstack(precis.loc[param, ["5.0%", "95.0%"]].values), [i, i], color="k", ax=ax)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-13880436d307> in <module>
      2 precis = precis.drop(columns = ["_AutoMultivariateNormal_latent"]).T
      3 fig, ax = plt.subplots()
----> 4 sns.pointplot(precis["mean"], precis.index, join=False, ax=ax)
      5 for i, param in enumerate(precis.index):
      6   sns.lineplot(np.hstack(precis.loc[param, ["5.0%", "95.0%"]].values), [i, i], color="k", ax=ax)

/opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages/seaborn/_decorators.py in inner_f(*args, **kwargs)
     44             )
     45         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 46         return f(**kwargs)
     47     return inner_f
     48 

/opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages/seaborn/categorical.py in pointplot(x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, seed, markers, linestyles, dodge, join, scale, orient, color, palette, errwidth, capsize, ax, **kwargs)
   3363                             estimator, ci, n_boot, units, seed,
   3364                             markers, linestyles, dodge, join, scale,
-> 3365                             orient, color, palette, errwidth, capsize)
   3366 
   3367     if ax is None:

/opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages/seaborn/categorical.py in __init__(self, x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, seed, markers, linestyles, dodge, join, scale, orient, color, palette, errwidth, capsize)
   1654         """Initialize the plotter."""
   1655         self.establish_variables(x, y, hue, data, orient,
-> 1656                                  order, hue_order, units)
   1657         self.establish_colors(color, palette, 1)
   1658         self.estimate_statistic(estimator, ci, n_boot, seed)

/opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages/seaborn/categorical.py in establish_variables(self, x, y, hue, data, orient, order, hue_order, units)
    155             # Figure out the plotting orientation
    156             orient = infer_orient(
--> 157                 x, y, orient, require_numeric=self.require_numeric
    158             )
    159 

/opt/hostedtoolcache/Python/3.7.9/x64/lib/python3.7/site-packages/seaborn/_core.py in infer_orient(x, y, orient, require_numeric)
   1327     elif require_numeric and "numeric" not in (x_type, y_type):
   1328         err = "Neither the `x` nor `y` variable appears to be numeric."
-> 1329         raise TypeError(err)
   1330 
   1331     else:

TypeError: Neither the `x` nor `y` variable appears to be numeric.
../_images/05_Chapter05_13_1.png

4.5. Code 5.6

def model54(MedianAgeMarriage_s):
  
  a = pyro.sample("a", dist.Normal(tensor(0.), tensor(10.)))
  b = pyro.sample("bA", dist.Normal(tensor(0.), tensor(1.)))  
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))
  mu = a + b * MedianAgeMarriage_s

  marriage_s = pyro.sample("marriage", dist.Normal(mu, sigma))
  return marriage_s

conditioned54 = pyro.condition(model54, data={"marriage": Marriage_s})
guide54 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned54)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned54,
    guide=guide54,
    optim=pyro.optim.Adam({"lr": 0.1}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(MedianAgeMarriage_s) for t in range(num_steps)]
laplace_guide54 = guide54.laplace_approximation(MedianAgeMarriage_s)
pred54 = pyro.infer.Predictive(laplace_guide54, num_samples=1000)  
precis = summary(pred54.get_samples()) 

4.6. Code 5.7

mu = precis.loc["mean", "a"]  + precis.loc["mean", "bA"]*waffle_divorce['MedianAgeMarriage_s']
m_resid = waffle_divorce['MedianAgeMarriage_s'] - mu

4.7. Code 5.8

fig, ax = plt.subplots()
ax.scatter(waffle_divorce['MedianAgeMarriage_s'], waffle_divorce['Marriage_s'])

x = waffle_divorce['MedianAgeMarriage_s'].sort_values().unique()
y = precis.loc["mean", "a"]  + precis.loc["mean", "bA"]*x

ax.plot(x,y)

for i in range(len(m_resid)):
  x = waffle_divorce['MedianAgeMarriage_s'][i]
  y = waffle_divorce['Marriage_s'][i]
  sns.lineplot([x, x], [mu[i], y], color="k", alpha=1)
ax.set_xlabel('MedianMarriage_s')
ax.set_ylabel('Marriage_s')
fig.tight_layout()
../_images/05_Chapter05_19_0.png

4.8. Code 5.9

def sim(post, A, R, num_samples=1000):
  mu = post["bR"].detach().reshape((num_samples,1))*R.reshape(1, R.shape[0])
  mu += post["a"].detach().reshape((num_samples,1)) 
  mu += post["bA"].detach().reshape((num_samples,1))*A  
  sigma = post["sigma"].reshape(num_samples,1)
  divorce = dist.Normal(mu, 
                        sigma).sample()
  return divorce


def link(samples, A, R, num_samples=1000):

  mu = samples["bR"].detach().reshape((num_samples,1))*R.reshape(1, R.shape[0])
  mu += samples["a"].detach().reshape((num_samples,1)) 
  mu += samples["bA"].detach().reshape((num_samples,1))*A  
  return mu

A_avg = d['MedianAgeMarriage_s'].mean()
R_seq = torch.linspace(-3., 3., steps=30)

samples = samples53
mu = link(samples, A_avg, R_seq, num_samples=1000)
mu_mean = mu.mean(0)

mu_PI = stats.pi(mu, dim=0, prob=0.89)

R_sim = sim(samples, A_avg, R_seq)
R_PI = stats.pi(R_sim, dim=0, prob=0.89)


fig, ax = plt.subplots()
sns.lineplot(R_seq.numpy(), mu_mean.numpy(), ax=ax)
ax.fill_between(R_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
ax.fill_between(R_seq, R_PI[0], R_PI[1], color="k", alpha=0.2)
ax.set_xlabel('Marriage.s')
ax.set_ylabel('Divorce')
fig.tight_layout()
../_images/05_Chapter05_21_0.png

4.9. Code 5.10

def sim(post, A, R, num_samples=1000):
  
  mu = post["bA"].detach().reshape((num_samples,1))*A.reshape(1, A.shape[0]) 
  mu += post["bR"].detach().reshape((num_samples,1))*R
  mu += post["a"].detach().reshape((num_samples,1)) 
  
  sigma = post["sigma"].reshape(num_samples,1)
  divorce = dist.Normal(mu, 
                        sigma).sample()
  return divorce


def link(samples, A, R, num_samples=1000):

  mu = samples["bA"].detach().reshape((num_samples,1))*A.reshape(1, A.shape[0])    
  mu += samples["a"].detach().reshape((num_samples,1)) 
  mu += samples["bR"].detach().reshape((num_samples,1))*R
  return mu


R_avg = d['Marriage_s'].mean()
A_seq = torch.linspace(-3., 3.5, steps=30)

samples = samples53
mu = link(samples, A_seq, R_avg, num_samples=1000)
mu_mean = mu.mean(0)
mu_PI = stats.pi(mu, dim=0, prob=0.89)

A_sim = sim(samples, A_seq, R_avg)
A_PI = stats.pi(A_sim, dim=0, prob=0.89)


fig, ax = plt.subplots()
sns.lineplot(A_seq.numpy(), mu_mean.numpy(), ax=ax)
ax.fill_between(A_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
ax.fill_between(A_seq, A_PI[0], A_PI[1], color="k", alpha=0.2)
ax.set_xlabel('MedianAgeMarriage.s')
ax.set_ylabel('Divorce')
fig.tight_layout()
../_images/05_Chapter05_23_0.png

4.10. Code 5.11

def sim(post, A, R, num_samples=1000):
  
  mu = post["bA"].detach().reshape((num_samples,1))*A.reshape(1, A.shape[0]) 
  mu += post["bR"].detach().reshape((num_samples,1))*R.reshape(1, R.shape[0]) 
  mu += post["a"].detach().reshape((num_samples,1)) 
  
  sigma = post["sigma"].reshape(num_samples,1)
  divorce = dist.Normal(mu, 
                        sigma).sample()
  return divorce


#mu = post53.loc["mean", "a"]  + post53.loc["mean", "bA"]*waffle_divorce['MedianAgeMarriage_s'] + post53.loc["mean", "bR"]*waffle_divorce['Marriage_s']
#mu = tensor(mu, dtype=torch.float)
num_samples = 1000
A = tensor(waffle_divorce['MedianAgeMarriage_s'], dtype=torch.float)
R = tensor(waffle_divorce['Marriage_s'], dtype=torch.float)
mu = samples["bA"].detach().reshape((num_samples,1))*A.reshape(1, A.shape[0]) 
mu += samples["bR"].detach().reshape((num_samples,1))*R.reshape(1, R.shape[0]) 
mu += samples["a"].detach().reshape((num_samples,1)) 

mu_mean = mu.mean(0)

mu_PI = stats.pi(mu, dim=0, prob=0.89)

divorce_sim = sim(samples, waffle_divorce['MedianAgeMarriage_s'].values, waffle_divorce['Marriage_s'].values)
divorce_PI = stats.pi(divorce_sim, dim=0, prob=0.89)

4.11. Code 5.12

fig, ax = plt.subplots()
ax.scatter(d.Divorce, mu_mean)
ax.set_xlabel('Observed divorce')
ax.set_ylabel('Predicted divorce')
x = torch.linspace(6, 14, 101)
sns.lineplot(x, x, color="k", ax=ax)
ax.lines[-1].set_linestyle("--")
for i in range(d.shape[0]):
    sns.lineplot(divorce[i].repeat(2), mu_PI[:, i], color="k", ax=ax)
../_images/05_Chapter05_27_0.png

4.12. Code 5.13

identify = mu_mean.sort(descending=True)[1][:2].numpy()
for i in identify:
  ax.annotate(d["Loc"][i], (divorce[i].numpy(), mu_mean[i].numpy()), xytext=(-35, -5),
              textcoords="offset pixels")
fig
../_images/05_Chapter05_29_0.png

4.13. Code 5.14

divorce_resid = tensor(d.Divorce) - mu_mean
divorce_resid = divorce_resid.numpy()
o = divorce_resid.argsort()

fig, ax = plt.subplots(figsize=(8, 10))
ax.scatter(divorce_resid[o], d.Loc[o])
ax.set(xlim=(-6, 5))
ax.axvline(x=0, c="k", alpha=0.2)
for i in range(d.shape[0]):
    j = o[i]
    sns.lineplot(divorce[j] - mu_PI[:, j], [i, i], color="k", ax=ax)
    sns.scatterplot(divorce[j] - divorce_PI[:, j], [i, i], color="gray", marker="+", ax=ax)
fig.tight_layout()
../_images/05_Chapter05_31_0.png

4.14. Code 5.15

N = 100
x_real = dist.Normal(0,1).rsample(torch.Size([N]))
x_spur = dist.Normal(x_real,1).rsample(torch.Size([N]))
y = dist.Normal(x_real,1).rsample(torch.Size([N]))
d = pd.DataFrame([x_real.numpy(), x_spur.numpy(),  y.numpy()]).T
d.columns = ["x_real", "x_spur", "y"]

4.15. Code 5.16

milk_df = pd.read_csv("https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/milk.csv", sep=";")
d = milk_df 
d
clade species kcal.per.g perc.fat perc.protein perc.lactose mass neocortex.perc
0 Strepsirrhine Eulemur fulvus 0.49 16.60 15.42 67.98 1.95 55.16
1 Strepsirrhine E macaco 0.51 19.27 16.91 63.82 2.09 NaN
2 Strepsirrhine E mongoz 0.46 14.11 16.85 69.04 2.51 NaN
3 Strepsirrhine E rubriventer 0.48 14.91 13.18 71.91 1.62 NaN
4 Strepsirrhine Lemur catta 0.60 27.28 19.50 53.22 2.19 NaN
5 New World Monkey Alouatta seniculus 0.47 21.22 23.58 55.20 5.25 64.54
6 New World Monkey A palliata 0.56 29.66 23.46 46.88 5.37 64.54
7 New World Monkey Cebus apella 0.89 53.41 15.80 30.79 2.51 67.64
8 New World Monkey Saimiri boliviensis 0.91 46.08 23.34 30.58 0.71 NaN
9 New World Monkey S sciureus 0.92 50.58 22.33 27.09 0.68 68.85
10 New World Monkey Cebuella pygmaea 0.80 41.35 20.85 37.80 0.12 58.85
11 New World Monkey Callimico goeldii 0.46 3.93 25.30 70.77 0.47 61.69
12 New World Monkey Callithrix jacchus 0.71 38.38 20.09 41.53 0.32 60.32
13 New World Monkey Leontopithecus rosalia 0.71 36.90 21.27 41.83 0.60 NaN
14 Old World Monkey Chlorocebus pygerythrus 0.73 39.17 14.65 46.18 3.47 NaN
15 Old World Monkey Miopithecus talpoin 0.68 40.15 18.08 41.77 1.55 69.97
16 Old World Monkey M fuscata 0.72 53.05 13.00 33.95 7.08 NaN
17 Old World Monkey M mulatta 0.97 55.51 13.17 31.32 3.24 70.41
18 Old World Monkey M sinica 0.79 48.90 13.91 37.19 7.94 NaN
19 Old World Monkey Papio spp 0.84 54.31 10.97 34.72 12.30 73.40
20 Ape Nomascus concolor 0.48 15.96 12.52 71.52 7.59 NaN
21 Ape Hylobates lar 0.62 34.51 12.57 52.92 5.37 67.53
22 Ape Symphalangus syndactylus 0.51 26.42 13.46 60.12 10.72 NaN
23 Ape Pongo pygmaeus 0.54 37.78 7.37 54.85 35.48 71.26
24 Ape Gorilla gorilla gorilla 0.49 27.18 16.29 56.53 79.43 72.60
25 Ape G gorilla beringei 0.53 30.59 20.77 48.64 97.72 NaN
26 Ape Pan paniscus 0.48 21.18 11.68 67.14 40.74 70.24
27 Ape P troglodytes 0.55 36.84 9.54 53.62 33.11 76.30
28 Ape Homo sapiens 0.71 50.49 9.84 39.67 54.95 75.49

4.16. Code 5.17

def model55(neocortex_perc):
  a = pyro.sample("a", dist.Normal(tensor(0.), tensor(100.)))
  bn = pyro.sample("bn", dist.Normal(tensor(0.), tensor(1.)))  
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))
  mu = a + bn * neocortex_perc
  kcal_per_g = pyro.sample("kcal_per_g", dist.Normal(mu.float(), sigma.float()))
  return kcal_per_g

dcc = d.dropna()
kcal_per_g = tensor(dcc['kcal.per.g'], dtype=torch.float)
neocortex_perc = tensor(dcc['neocortex.perc'], dtype=torch.float)

conditioned55 = pyro.condition(model55, data={"kcal_per_g": kcal_per_g})
guide55 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned55)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned55,
    guide=guide55,
    optim=pyro.optim.Adam({"lr": 0.01}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(neocortex_perc.float()) for t in range(num_steps)]
plt.plot(losses)
[<matplotlib.lines.Line2D at 0x7efed243de48>]
../_images/05_Chapter05_37_1.png

4.17. Code 5.18

d['neocortex.perc']
0     55.16
1       NaN
2       NaN
3       NaN
4       NaN
5     64.54
6     64.54
7     67.64
8       NaN
9     68.85
10    58.85
11    61.69
12    60.32
13      NaN
14      NaN
15    69.97
16      NaN
17    70.41
18      NaN
19    73.40
20      NaN
21    67.53
22      NaN
23    71.26
24    72.60
25      NaN
26    70.24
27    76.30
28    75.49
Name: neocortex.perc, dtype: float64

4.18. Code 5.19

dcc = d.dropna()

4.19. Code 5.20

def model55(data):
  a = pyro.sample("a", dist.Normal(tensor(0.), tensor(100.)))
  bn = pyro.sample("bn", dist.Normal(tensor(0.), tensor(1.)))  
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(1.)))
  mu = a + bn * data
  kcal_per_g = pyro.sample("kcal_per_g", dist.Normal(mu, sigma))
  return kcal_per_g

kcal_per_g = tensor(dcc['kcal.per.g'], dtype=torch.float)
neocortex_perc = tensor(dcc['neocortex.perc'], dtype=torch.float)

conditioned55 = pyro.condition(model55, data={"kcal_per_g": kcal_per_g})
guide55 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned55)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned55,
    guide=guide55,
    optim=pyro.optim.Adam({"lr": 0.005}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(neocortex_perc.float()) for t in range(num_steps)]
plt.plot(losses)

laplace_guide55 = guide55.laplace_approximation(neocortex_perc)
pred55 = pyro.infer.Predictive(laplace_guide55, num_samples=1000)  
precis55 = summary(pred55.get_samples()) 
../_images/05_Chapter05_43_0.png
def model(x, y):
    a = pyro.sample("a", dist.Normal(tensor(0.), tensor(100.)))
    bn = pyro.sample("bn", dist.Normal(tensor(0.), tensor(1.)))
    mu = a + bn * x
    sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(1.)))
    with pyro.plate("plate"):
        pyro.sample("kcal_per_g", dist.Normal(mu, sigma), obs=y)

def model55(x,y):
  a = pyro.sample("a", dist.Normal(tensor(0.), tensor(100.)))
  bn = pyro.sample("bn", dist.Normal(tensor(0.), tensor(1.)))  
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(1.)))
  mu = a + bn * x
  kcal_per_g = pyro.sample("kcal_per_g", dist.Normal(mu, sigma))#, obs=y)
  return kcal_per_g

kcal_per_g = tensor(dcc['kcal.per.g'], dtype=torch.float).float()
neocortex_perc = tensor(dcc['neocortex.perc'], dtype=torch.float).float()
pyro.clear_param_store()

conditioned55 = pyro.condition(model55, data={"kcal_per_g": kcal_per_g})
guide55 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned55)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned55,
    guide=guide55,
    optim=pyro.optim.Adam({"lr": 0.005}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]
plt.plot(losses)

laplace_guide55 = guide55.laplace_approximation(neocortex_perc, kcal_per_g)
pred55 = pyro.infer.Predictive(laplace_guide55, num_samples=1000)  
precis55 = summary(pred55.get_samples()) 
../_images/05_Chapter05_44_0.png
precis55
_AutoMultivariateNormal_latent a bn sigma
mean [38.724445, -0.5592735, 1.1130564] 38.724472 -0.5592732 0.7499051
std [5.122513, 0.07533453, 0.24596252] 5.122513 0.07533453 0.045710217
5.0% [30.631195, -0.6852666, 0.7258537] 30.631195 -0.6852666 0.67389476
95.0% [47.238846, -0.43942103, 1.5153502] 47.238846 -0.43942103 0.81985277
neocortex_perc
tensor([55.1600, 64.5400, 64.5400, 67.6400, 68.8500, 58.8500, 61.6900, 60.3200,
        69.9700, 70.4100, 73.4000, 67.5300, 71.2600, 72.6000, 70.2400, 76.3000,
        75.4900])
"""def model(x, y):
    a = pyro.sample("a", dist.Normal(tensor(0.), tensor(100.)))
    bn = pyro.sample("bn", dist.Normal(tensor(0.), tensor(1.)))
    mu = a + bn * x
    sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(1.)))
    with pyro.plate("plate"):
        pyro.sample("kcal_per_g", dist.Normal(mu, sigma), obs=y)


conditioned = pyro.condition(model, data={"kcal_per_g": kcal_per_g})
guide = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned,
    guide=guide,
    optim=pyro.optim.Adam({"lr": 0.005}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]
plt.plot(losses)

laplace_guide = guide.laplace_approximation(neocortex_perc, kcal_per_g)
pred = pyro.infer.Predictive(laplace_guide, num_samples=1000)"""
'def model(x, y):\n    a = pyro.sample("a", dist.Normal(tensor(0.), tensor(100.)))\n    bn = pyro.sample("bn", dist.Normal(tensor(0.), tensor(1.)))\n    mu = a + bn * x\n    sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(1.)))\n    with pyro.plate("plate"):\n        pyro.sample("kcal_per_g", dist.Normal(mu, sigma), obs=y)\n\n\nconditioned = pyro.condition(model, data={"kcal_per_g": kcal_per_g})\nguide = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned)\npyro.clear_param_store()\n\nsvi = pyro.infer.SVI(\n    model=conditioned,\n    guide=guide,\n    optim=pyro.optim.Adam({"lr": 0.005}),\n    loss=pyro.infer.Trace_ELBO(),\n)\nnum_steps = 5000\nlosses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]\nplt.plot(losses)\n\nlaplace_guide = guide.laplace_approximation(neocortex_perc, kcal_per_g)\npred = pyro.infer.Predictive(laplace_guide, num_samples=1000)'

4.20. Code 5.21

"""pred_summary55 = summary(pred.get_samples())
pred_summary55"""
'pred_summary55 = summary(pred.get_samples())\npred_summary55'

4.21. Code 5.22

"""pred_summary55.loc["mean", "bn"] * (76-55)"""
'pred_summary55.loc["mean", "bn"] * (76-55)'

4.22. Code 5.23

4.23. Code 5.24

dcc['log_mass'] = np.log(dcc['mass'])

4.24. Code 5.25

"""
def model56(logmass, kcal):
  a = pyro.sample("a", dist.Normal(0, 100))
  bm = pyro.sample("bn", dist.Normal(0, 1))
  sigma = pyro.sample("sigma", dist.Uniform(0, 1))

  mu = a + bm * logmass
  with pyro.plate("data"):
    pyro.sample("kcal_per_g", dist.Normal(mu, sigma))

kcal_per_g = tensor(dcc['kcal.per.g'], dtype=torch.float)
neocortex_perc = tensor(dcc['neocortex.perc'], dtype=torch.float)

conditioned56 = pyro.condition(model56, data={"kcal_per_g": kcal_per_g})
guide56 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned56)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned56,
    guide=guide56,
    optim=pyro.optim.Adam({"lr": 0.005}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]
plt.plot(losses)

laplace_guide56 = guide56.laplace_approximation(neocortex_perc, kcal_per_g)
pred56 = pyro.infer.Predictive(laplace_guide56, num_samples=1000)  
precis56 = summary(pred56.get_samples()) 

precis56
"""
'\ndef model56(logmass, kcal):\n  a = pyro.sample("a", dist.Normal(0, 100))\n  bm = pyro.sample("bn", dist.Normal(0, 1))\n  sigma = pyro.sample("sigma", dist.Uniform(0, 1))\n\n  mu = a + bm * logmass\n  with pyro.plate("data"):\n    pyro.sample("kcal_per_g", dist.Normal(mu, sigma))\n\nkcal_per_g = tensor(dcc[\'kcal.per.g\'], dtype=torch.float)\nneocortex_perc = tensor(dcc[\'neocortex.perc\'], dtype=torch.float)\n\nconditioned56 = pyro.condition(model56, data={"kcal_per_g": kcal_per_g})\nguide56 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned56)\npyro.clear_param_store()\n\nsvi = pyro.infer.SVI(\n    model=conditioned56,\n    guide=guide56,\n    optim=pyro.optim.Adam({"lr": 0.005}),\n    loss=pyro.infer.Trace_ELBO(),\n)\nnum_steps = 5000\nlosses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]\nplt.plot(losses)\n\nlaplace_guide56 = guide56.laplace_approximation(neocortex_perc, kcal_per_g)\npred56 = pyro.infer.Predictive(laplace_guide56, num_samples=1000)  \nprecis56 = summary(pred56.get_samples()) \n\nprecis56\n'
def model56(logmass, kcal):
  a = pyro.sample("a", dist.Normal(tensor(0.), tensor(100.)))
  bm = pyro.sample("bm", dist.Normal(tensor(0.), tensor(1.)))
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(1.)))

  mu = a + bm * logmass
  #print(mu)
  kcal_per_g = pyro.sample("kcal_per_g", dist.Normal(mu.float(), sigma.float()))
  return kcal_per_g


def guide56(logmass, kcal):
  a_param = pyro.param("a_param", tensor(10.))
  bm_param = pyro.param("bm_param", tensor(.5))
  sigma_param = pyro.param("sigma_param", tensor(.5))
  return pyro.sample("a", dist.Delta(a_param)), pyro.sample("bm", dist.Delta(bm_param)),  pyro.sample("sigma", dist.Delta(sigma_param))


conditioned56 = pyro.condition(model56, data={"kcal_per_g": kcal_per_g})
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned56,
    guide=guide56,
    optim=pyro.optim.Adam({"lr": 0.001}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]
plt.plot(losses)
[<matplotlib.lines.Line2D at 0x7efed21b0f98>]
../_images/05_Chapter05_58_1.png
pyro.param("bm_param")
tensor(-0.1276, requires_grad=True)
"""

kcal_per_g = tensor(dcc['kcal.per.g'], dtype=torch.float)
neocortex_perc = tensor(dcc['neocortex.perc'], dtype=torch.float)


guide56 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned56)
pyro.clear_param_store()

svi = pyro.infer.SVI(
    model=conditioned56,
    guide=guide56,
    optim=pyro.optim.Adam({"lr": 0.005}),
    loss=pyro.infer.Trace_ELBO(),
)
num_steps = 5000
losses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]
plt.plot(losses)

laplace_guide56 = guide56.laplace_approximation(neocortex_perc, kcal_per_g)
pred56 = pyro.infer.Predictive(laplace_guide56, num_samples=1000)  
precis56 = summary(pred56.get_samples()) 

precis56
"""
'\n\nkcal_per_g = tensor(dcc[\'kcal.per.g\'], dtype=torch.float)\nneocortex_perc = tensor(dcc[\'neocortex.perc\'], dtype=torch.float)\n\n\nguide56 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned56)\npyro.clear_param_store()\n\nsvi = pyro.infer.SVI(\n    model=conditioned56,\n    guide=guide56,\n    optim=pyro.optim.Adam({"lr": 0.005}),\n    loss=pyro.infer.Trace_ELBO(),\n)\nnum_steps = 5000\nlosses = [svi.step(neocortex_perc, kcal_per_g) for t in range(num_steps)]\nplt.plot(losses)\n\nlaplace_guide56 = guide56.laplace_approximation(neocortex_perc, kcal_per_g)\npred56 = pyro.infer.Predictive(laplace_guide56, num_samples=1000)  \nprecis56 = summary(pred56.get_samples()) \n\nprecis56\n'

4.25. Code 5.26

4.26. Code 5.27

4.27. Code 5.28

4.28. Code 5.29

N = 100
height = dist.Normal(10., 2.).rsample(torch.Size([N]))
leg_prop = dist.Uniform(0.4, 0.5).rsample(torch.Size([N]))

leg_left = leg_prop*height + dist.Normal(0., 0.02).rsample(torch.Size([N]))
leg_right = leg_prop*height + dist.Normal(0., 0.02).rsample(torch.Size([N]))

d = pd.DataFrame([height, leg_left, leg_right]).T
d.columns = ['height', 'leg_left', 'leg_right']
d['leg_left']
0     tensor(5.7180)
1     tensor(4.9624)
2     tensor(4.3404)
3     tensor(5.4653)
4     tensor(4.8009)
           ...      
95    tensor(4.0597)
96    tensor(3.1883)
97    tensor(4.0643)
98    tensor(5.9066)
99    tensor(4.1089)
Name: leg_left, Length: 100, dtype: object

4.29. Code 5.30

def model58(ll, lr):
  a = pyro.sample("a", dist.Normal(tensor(10.), tensor(100.)))
  bl = pyro.sample("bl", dist.Normal(tensor(2.), tensor(10.)))
  br = pyro.sample("br", dist.Normal(tensor(2.), tensor(10.)))
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))

  mu = a + bl*ll + br*lr
  with pyro.plate("height"):
    pyro.sample("height", dist.Normal(mu, sigma))
  #return height


conditioned58 = pyro.condition(model58, data={"height": tensor(height, dtype=torch.float)})
guide58 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned58)
pyro.clear_param_store()

svi = pyro.infer.SVI(model=conditioned58,
                     guide=guide58,
                     optim=pyro.optim.Adam({"lr": 0.01}),
                     loss=pyro.infer.Trace_ELBO())
num_steps = 5000
losses = [svi.step(tensor(leg_left, dtype=torch.float), tensor(leg_right, dtype=torch.float)) for t in range(num_steps)]
plt.plot(losses)

laplace_guide58 = guide58.laplace_approximation(tensor(leg_left, dtype=torch.float), tensor(leg_right, dtype=torch.float))
pred58 = pyro.infer.Predictive(laplace_guide58, num_samples=1000)  
precis58 = summary(pred58.get_samples())
../_images/05_Chapter05_69_0.png
precis58
_AutoMultivariateNormal_latent a bl br sigma
mean [1.0337458, 3.141063, -1.1763383, -2.7893252] 1.033746 3.1410625 -1.1763381 0.58046764
std [0.2597808, 1.6465611, 1.6511321, 0.07707023] 0.2597808 1.6465611 1.6511321 0.042131346
5.0% [0.61492366, 0.4741156, -3.8326325, -2.9146008] 0.61492366 0.4741156 -3.8326325 0.5143649
95.0% [1.4572909, 5.7731295, 1.4903916, -2.6606252] 1.4572909 5.7731295 1.4903916 0.65337145

4.30. Code 5.34

"""
def model59(ll):
  a = pyro.sample("a", dist.Normal(tensor(10.), tensor(100.)))
  bl = pyro.sample("bl", dist.Normal(tensor(2.), tensor(10.)))
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))

  mu = a + bl*ll
  height = pyro.sample("height", dist.Normal(mu, sigma))
  return height


conditioned59 = pyro.condition(model59, data={"height": tensor(height, dtype=torch.float)})
guide59 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned59)
pyro.clear_param_store()

svi = pyro.infer.SVI(model=conditioned59,
                     guide=guide59,
                     optim=pyro.optim.Adam({"lr": 0.01}),
                     loss=pyro.infer.Trace_ELBO())
num_steps = 5000
losses = [svi.step(tensor(leg_left, dtype=torch.float)) for t in range(num_steps)]
plt.plot(losses)

laplace_guide59 = guide59.laplace_approximation(tensor(leg_left, dtype=torch.float))
pred59 = pyro.infer.Predictive(laplace_guide59, num_samples=1000)  
precis59 = summary(pred59.get_samples())
precis59
"""
'\ndef model59(ll):\n  a = pyro.sample("a", dist.Normal(tensor(10.), tensor(100.)))\n  bl = pyro.sample("bl", dist.Normal(tensor(2.), tensor(10.)))\n  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))\n\n  mu = a + bl*ll\n  height = pyro.sample("height", dist.Normal(mu, sigma))\n  return height\n\n\nconditioned59 = pyro.condition(model59, data={"height": tensor(height, dtype=torch.float)})\nguide59 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned59)\npyro.clear_param_store()\n\nsvi = pyro.infer.SVI(model=conditioned59,\n                     guide=guide59,\n                     optim=pyro.optim.Adam({"lr": 0.01}),\n                     loss=pyro.infer.Trace_ELBO())\nnum_steps = 5000\nlosses = [svi.step(tensor(leg_left, dtype=torch.float)) for t in range(num_steps)]\nplt.plot(losses)\n\nlaplace_guide59 = guide59.laplace_approximation(tensor(leg_left, dtype=torch.float))\npred59 = pyro.infer.Predictive(laplace_guide59, num_samples=1000)  \nprecis59 = summary(pred59.get_samples())\nprecis59\n'

4.31. Code 5.35

milk_df = pd.read_csv("https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/milk.csv", sep=";")
d = milk_df 
dmilk_df = pd.read_csv("https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/milk.csv", sep=";")
d = milk_df 
d
clade species kcal.per.g perc.fat perc.protein perc.lactose mass neocortex.perc
0 Strepsirrhine Eulemur fulvus 0.49 16.60 15.42 67.98 1.95 55.16
1 Strepsirrhine E macaco 0.51 19.27 16.91 63.82 2.09 NaN
2 Strepsirrhine E mongoz 0.46 14.11 16.85 69.04 2.51 NaN
3 Strepsirrhine E rubriventer 0.48 14.91 13.18 71.91 1.62 NaN
4 Strepsirrhine Lemur catta 0.60 27.28 19.50 53.22 2.19 NaN
5 New World Monkey Alouatta seniculus 0.47 21.22 23.58 55.20 5.25 64.54
6 New World Monkey A palliata 0.56 29.66 23.46 46.88 5.37 64.54
7 New World Monkey Cebus apella 0.89 53.41 15.80 30.79 2.51 67.64
8 New World Monkey Saimiri boliviensis 0.91 46.08 23.34 30.58 0.71 NaN
9 New World Monkey S sciureus 0.92 50.58 22.33 27.09 0.68 68.85
10 New World Monkey Cebuella pygmaea 0.80 41.35 20.85 37.80 0.12 58.85
11 New World Monkey Callimico goeldii 0.46 3.93 25.30 70.77 0.47 61.69
12 New World Monkey Callithrix jacchus 0.71 38.38 20.09 41.53 0.32 60.32
13 New World Monkey Leontopithecus rosalia 0.71 36.90 21.27 41.83 0.60 NaN
14 Old World Monkey Chlorocebus pygerythrus 0.73 39.17 14.65 46.18 3.47 NaN
15 Old World Monkey Miopithecus talpoin 0.68 40.15 18.08 41.77 1.55 69.97
16 Old World Monkey M fuscata 0.72 53.05 13.00 33.95 7.08 NaN
17 Old World Monkey M mulatta 0.97 55.51 13.17 31.32 3.24 70.41
18 Old World Monkey M sinica 0.79 48.90 13.91 37.19 7.94 NaN
19 Old World Monkey Papio spp 0.84 54.31 10.97 34.72 12.30 73.40
20 Ape Nomascus concolor 0.48 15.96 12.52 71.52 7.59 NaN
21 Ape Hylobates lar 0.62 34.51 12.57 52.92 5.37 67.53
22 Ape Symphalangus syndactylus 0.51 26.42 13.46 60.12 10.72 NaN
23 Ape Pongo pygmaeus 0.54 37.78 7.37 54.85 35.48 71.26
24 Ape Gorilla gorilla gorilla 0.49 27.18 16.29 56.53 79.43 72.60
25 Ape G gorilla beringei 0.53 30.59 20.77 48.64 97.72 NaN
26 Ape Pan paniscus 0.48 21.18 11.68 67.14 40.74 70.24
27 Ape P troglodytes 0.55 36.84 9.54 53.62 33.11 76.30
28 Ape Homo sapiens 0.71 50.49 9.84 39.67 54.95 75.49

4.32. Code 5.36

def model510(pf):
  a = pyro.sample("a", dist.Normal(tensor(0.6), tensor(10.)))
  bf = pyro.sample("bf", dist.Normal(tensor(2.), tensor(1.)))
  sigma = pyro.sample("sigma", dist.Uniform(tensor(0.), tensor(10.)))

  mu = a + bf*pf
  kcal_per_g = pyro.sample("kcal_per_g", dist.Normal(mu, sigma))
  return height


conditioned510 = pyro.condition(model510, data={"kcal_per_g": tensor(d['kcal.per.g'], dtype=torch.float)})
guide510 = pyro.infer.autoguide.AutoLaplaceApproximation(conditioned510)
pyro.clear_param_store()

svi = pyro.infer.SVI(model=conditioned510,
                     guide=guide510,
                     optim=pyro.optim.Adam({"lr": 0.01}),
                     loss=pyro.infer.Trace_ELBO())
num_steps = 5000
losses = [svi.step(tensor(d['perc.fat'], dtype=torch.float)) for t in range(num_steps)]
plt.plot(losses)

laplace_guide510 = guide510.laplace_approximation(tensor(d['perc.fat'], dtype=torch.float))
pred510 = pyro.infer.Predictive(laplace_guide510, num_samples=1000)  
precis510 = summary(pred510.get_samples())
precis510
_AutoMultivariateNormal_latent a bf sigma
mean [0.29973564, 0.010284884, -4.850477] 0.29973575 0.010284895 0.0784034
std [0.037797228, 0.0010413628, 0.14234424] 0.037797228 0.0010413628 0.010974918
5.0% [0.23592636, 0.008497991, -5.0775146] 0.23592636 0.008497991 0.06196748
95.0% [0.3620291, 0.011877116, -4.621042] 0.3620291 0.011877116 0.09746605
../_images/05_Chapter05_76_1.png

4.33. Code 5.37

4.34. Code 5.38

4.35. Code 5.39

4.36. Code 5.40

4.37. Code 5.41

4.38. Code 5.42

4.39. Code 5.43

4.40. Code 5.44

4.41. Code 5.45

4.42. Code 5.46

4.43. Code 5.47

4.44. Code 5.48

4.45. Code 5.49

4.46. Code 5.50

4.47. Code 5.51

4.48. Code 5.52

4.49. Code 5.53

4.50. Code 5.54

4.51. Code 5.55

4.52. Code 5.56

4.53. Code 5.57

4.54. Code 5.58

4.55. Code 5.59

4.56. Code 5.60

4.57. Code 5.61

4.58. Code 5.62