Stochastic SIR model with Python
Libraries
import numpy as np
import math
import pandas as pd
import pythran
%%writefile .pythranrc
[compiler]
include_dirs=/usr/include/openblas
Writing .pythranrc
%load_ext pythran.magic
# Seed NumPy's legacy global generator so the plain-Python runs are repeatable.
# NOTE(review): `%timeit simulate()` further down draws from this same
# generator an auto-calibrated (timing-dependent) number of times before
# `sir_out` is built, so the stored trajectory probably still varies between
# machines; reseed immediately before the run that must be reproducible.
np.random.seed(seed=123)
Plain Python version
def sir(u, parms, t):
    """Advance the stochastic SIR state by one discrete time step.

    Each susceptible becomes infected with probability
    ``1 - exp(-lambda * dt)``, where ``lambda = beta * (I + iota) / N``.
    Each infective present at the start of the step recovers with probability
    ``1 - exp(-gamma * dt)``. Both counts are binomial draws from NumPy's global
    generator.

    Parameters
    ----------
    u : sequence of 4 numbers
        Current state ``[S, I, R, Y]``; ``Y`` accumulates new infections.
    parms : sequence of 5 numbers
        ``[beta, gamma, iota, N, dt]``: infection rate coefficient, recovery
        rate, additive term on ``I`` in the force of infection, normalising
        population size, and step length.
    t : float
        Current time. Not used by the update; kept for a uniform signature.

    Returns
    -------
    list
        The new state ``[S, I, R, Y]``.
    """
    beta, gamma, iota, n_pop, step = parms
    susceptible, infective, recovered, cumulative = u

    force = beta * (infective + iota) / n_pop
    p_infect = 1.0 - math.exp(-force * step)
    p_recover = 1.0 - math.exp(-gamma * step)

    # Draw order matters for reproducibility under a fixed seed:
    # infections first, then recoveries (from the pre-step infectives).
    new_infections = np.random.binomial(susceptible, p_infect)
    new_recoveries = np.random.binomial(infective, p_recover)

    next_s = susceptible - new_infections
    next_i = infective + new_infections - new_recoveries
    next_r = recovered + new_recoveries
    next_y = cumulative + new_infections
    return [next_s, next_i, next_r, next_y]
def simulate():
    """Run one stochastic SIR realisation on a fixed time grid.

    The grid has 2001 points from 0 to 200, i.e. spacing 0.1, which matches
    the ``dt`` entry of the parameter vector. The run starts from
    ``[S, I, R, Y] = [999, 1, 0, 0]`` and applies ``sir`` once per grid
    point after the first.

    Returns
    -------
    dict
        Keys ``'t'``, ``'S'``, ``'I'``, ``'R'``, ``'Y'`` mapping to float
        arrays of length 2001; entry ``j`` is the state at ``t[j]``.
    """
    params = [0.1, 0.05, 0.01, 1000.0, 0.1]
    t_final = 200
    n_points = 2001
    times = np.linspace(0, t_final, n_points)

    # One float trajectory per compartment, in S, I, R, Y order.
    traces = tuple(np.zeros(n_points) for _ in range(4))

    state = [999, 1, 0, 0]
    for idx, when in enumerate(times):
        # Index 0 records the initial state; later indices take one step first.
        if idx > 0:
            state = sir(state, params, when)
        for trace, value in zip(traces, state):
            trace[idx] = value

    s_trace, i_trace, r_trace, y_trace = traces
    return {'t': times, 'S': s_trace, 'I': i_trace, 'R': r_trace, 'Y': y_trace}
%timeit simulate()
8.89 ms ± 85.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Collect one plain-Python realisation into a DataFrame: one row per time
# point, one column per key of the returned dict.
sir_out = pd.DataFrame.from_dict(simulate())
sir_out
| | t | S | I | R | Y |
---|---|---|---|---|---|
0 | 0.0 | 999.0 | 1.0 | 0.0 | 0.0 |
1 | 0.1 | 999.0 | 1.0 | 0.0 | 0.0 |
2 | 0.2 | 999.0 | 1.0 | 0.0 | 0.0 |
3 | 0.3 | 999.0 | 1.0 | 0.0 | 0.0 |
4 | 0.4 | 999.0 | 1.0 | 0.0 | 0.0 |
5 | 0.5 | 999.0 | 1.0 | 0.0 | 0.0 |
6 | 0.6 | 999.0 | 1.0 | 0.0 | 0.0 |
7 | 0.7 | 999.0 | 1.0 | 0.0 | 0.0 |
8 | 0.8 | 999.0 | 1.0 | 0.0 | 0.0 |
9 | 0.9 | 999.0 | 1.0 | 0.0 | 0.0 |
10 | 1.0 | 999.0 | 1.0 | 0.0 | 0.0 |
11 | 1.1 | 999.0 | 1.0 | 0.0 | 0.0 |
12 | 1.2 | 999.0 | 1.0 | 0.0 | 0.0 |
13 | 1.3 | 999.0 | 1.0 | 0.0 | 0.0 |
14 | 1.4 | 999.0 | 1.0 | 0.0 | 0.0 |
15 | 1.5 | 999.0 | 1.0 | 0.0 | 0.0 |
16 | 1.6 | 999.0 | 1.0 | 0.0 | 0.0 |
17 | 1.7 | 999.0 | 1.0 | 0.0 | 0.0 |
18 | 1.8 | 999.0 | 1.0 | 0.0 | 0.0 |
19 | 1.9 | 999.0 | 1.0 | 0.0 | 0.0 |
20 | 2.0 | 999.0 | 1.0 | 0.0 | 0.0 |
21 | 2.1 | 999.0 | 1.0 | 0.0 | 0.0 |
22 | 2.2 | 999.0 | 1.0 | 0.0 | 0.0 |
23 | 2.3 | 999.0 | 1.0 | 0.0 | 0.0 |
24 | 2.4 | 999.0 | 1.0 | 0.0 | 0.0 |
25 | 2.5 | 999.0 | 1.0 | 0.0 | 0.0 |
26 | 2.6 | 999.0 | 1.0 | 0.0 | 0.0 |
27 | 2.7 | 999.0 | 1.0 | 0.0 | 0.0 |
28 | 2.8 | 999.0 | 1.0 | 0.0 | 0.0 |
29 | 2.9 | 999.0 | 1.0 | 0.0 | 0.0 |
... | ... | ... | ... | ... | ... |
1971 | 197.1 | 285.0 | 62.0 | 653.0 | 714.0 |
1972 | 197.2 | 285.0 | 62.0 | 653.0 | 714.0 |
1973 | 197.3 | 285.0 | 62.0 | 653.0 | 714.0 |
1974 | 197.4 | 285.0 | 61.0 | 654.0 | 714.0 |
1975 | 197.5 | 285.0 | 60.0 | 655.0 | 714.0 |
1976 | 197.6 | 285.0 | 60.0 | 655.0 | 714.0 |
1977 | 197.7 | 285.0 | 59.0 | 656.0 | 714.0 |
1978 | 197.8 | 285.0 | 58.0 | 657.0 | 714.0 |
1979 | 197.9 | 285.0 | 57.0 | 658.0 | 714.0 |
1980 | 198.0 | 285.0 | 57.0 | 658.0 | 714.0 |
1981 | 198.1 | 285.0 | 57.0 | 658.0 | 714.0 |
1982 | 198.2 | 285.0 | 57.0 | 658.0 | 714.0 |
1983 | 198.3 | 285.0 | 57.0 | 658.0 | 714.0 |
1984 | 198.4 | 285.0 | 57.0 | 658.0 | 714.0 |
1985 | 198.5 | 285.0 | 56.0 | 659.0 | 714.0 |
1986 | 198.6 | 285.0 | 56.0 | 659.0 | 714.0 |
1987 | 198.7 | 285.0 | 55.0 | 660.0 | 714.0 |
1988 | 198.8 | 284.0 | 56.0 | 660.0 | 715.0 |
1989 | 198.9 | 284.0 | 56.0 | 660.0 | 715.0 |
1990 | 199.0 | 282.0 | 58.0 | 660.0 | 717.0 |
1991 | 199.1 | 282.0 | 58.0 | 660.0 | 717.0 |
1992 | 199.2 | 282.0 | 57.0 | 661.0 | 717.0 |
1993 | 199.3 | 282.0 | 57.0 | 661.0 | 717.0 |
1994 | 199.4 | 282.0 | 57.0 | 661.0 | 717.0 |
1995 | 199.5 | 282.0 | 57.0 | 661.0 | 717.0 |
1996 | 199.6 | 282.0 | 56.0 | 662.0 | 717.0 |
1997 | 199.7 | 282.0 | 55.0 | 663.0 | 717.0 |
1998 | 199.8 | 281.0 | 56.0 | 663.0 | 718.0 |
1999 | 199.9 | 281.0 | 55.0 | 664.0 | 718.0 |
2000 | 200.0 | 281.0 | 54.0 | 665.0 | 718.0 |
2001 rows × 5 columns
Pythran compiled version
As the code above uses only simple Python and NumPy types, it is straightforward to obtain a compiled version of it using Pythran.
%%pythran -DUSE_XSIMD -march=native -O3
import numpy as np
import math
#pythran export sirp(float64 list, float64 list, float64)
def sirp(u,parms,t):
    """Advance the stochastic SIR state by one binomial step (Pythran version).

    Same update as the plain-Python ``sir`` function, written in the
    Python/NumPy subset that Pythran compiles.

    u     -- current state [S, I, R, Y]; Y accumulates new infections.
    parms -- [bet, gamm, iota, N, dt]: infection rate coefficient, recovery
             rate, additive term on I in the force of infection,
             normalising population size, and time step.
    t     -- current time; unused, kept so the signature matches ``sir``.

    Returns the new state as a list [S, I, R, Y].
    """
    bet,gamm,iota,N,dt=parms
    S,I,R,Y=u
    # Force of infection acting on each susceptible.
    lambd = bet*(I+iota)/N
    # Probability of at least one event within dt at a constant rate.
    ifrac = 1.0 - math.exp(-lambd*dt)
    rfrac = 1.0 - math.exp(-gamm*dt)
    # NOTE(review): inside compiled code these draws presumably come from
    # Pythran's own numpy.random implementation; confirm whether the
    # interpreter's np.random.seed has any effect on them.
    infection = np.random.binomial(S,ifrac)
    # Recoveries are drawn from the infectives present at the start of the
    # step, so individuals infected in this step cannot also recover in it.
    recovery = np.random.binomial(I,rfrac)
    return [S-infection,I+infection-recovery,R+recovery,Y+infection]
#pythran export simulatep()
def simulatep():
    """Run one stochastic SIR realisation using the compiled ``sirp`` step.

    Same setup as ``simulate``: 2001 time points from 0 to 200, initial
    state [999, 1, 0, 0], parameters [0.1, 0.05, 0.01, 1000.0, 0.1].

    Returns a dict with keys 't', 'S', 'I', 'R', 'Y' mapping to float arrays
    of length 2001. NOTE: when this was called from Python, the keys did not
    keep this order (the DataFrame built from the result listed Y, R, I, S,
    t), so select columns by name.
    """
    parms = [0.1, 0.05, 0.01, 1000.0, 0.1]
    tf = 200
    # 2001 points on [0, tf] gives spacing tf/(tl-1) = 0.1, matching dt in parms.
    tl = 2001
    t = np.linspace(0,tf,tl)
    S = np.zeros(tl)
    I = np.zeros(tl)
    R = np.zeros(tl)
    Y = np.zeros(tl)
    # Initial state [S, I, R, Y].
    u = [999,1,0,0]
    S[0],I[0],R[0],Y[0] = u
    for j in range(1,tl):
        u = sirp(u,parms,t[j])
        S[j],I[j],R[j],Y[j] = u
    return {'t':t,'S':S,'I':I,'R':R,'Y':Y}
%timeit simulatep()
434 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
This is around 20 times faster than the vanilla Python code (434 µs versus 8.89 ms per call), i.e. a little more than one order of magnitude.
# The dict returned by the compiled simulatep() does not come back with its
# keys in insertion order (an earlier run displayed Y, R, I, S, t). Pin the
# column order explicitly so this frame lines up with the plain-Python
# `sir_out` (t, S, I, R, Y).
sir_outp = pd.DataFrame(simulatep(), columns=['t', 'S', 'I', 'R', 'Y'])
sir_outp
| | Y | R | I | S | t |
---|---|---|---|---|---|
0 | 0.0 | 0.0 | 1.0 | 999.0 | 0.0 |
1 | 0.0 | 0.0 | 1.0 | 999.0 | 0.1 |
2 | 0.0 | 0.0 | 1.0 | 999.0 | 0.2 |
3 | 0.0 | 0.0 | 1.0 | 999.0 | 0.3 |
4 | 0.0 | 0.0 | 1.0 | 999.0 | 0.4 |
5 | 0.0 | 0.0 | 1.0 | 999.0 | 0.5 |
6 | 0.0 | 0.0 | 1.0 | 999.0 | 0.6 |
7 | 0.0 | 0.0 | 1.0 | 999.0 | 0.7 |
8 | 0.0 | 0.0 | 1.0 | 999.0 | 0.8 |
9 | 0.0 | 0.0 | 1.0 | 999.0 | 0.9 |
10 | 0.0 | 0.0 | 1.0 | 999.0 | 1.0 |
11 | 0.0 | 0.0 | 1.0 | 999.0 | 1.1 |
12 | 0.0 | 0.0 | 1.0 | 999.0 | 1.2 |
13 | 0.0 | 0.0 | 1.0 | 999.0 | 1.3 |
14 | 0.0 | 0.0 | 1.0 | 999.0 | 1.4 |
15 | 0.0 | 0.0 | 1.0 | 999.0 | 1.5 |
16 | 0.0 | 0.0 | 1.0 | 999.0 | 1.6 |
17 | 0.0 | 1.0 | 0.0 | 999.0 | 1.7 |
18 | 0.0 | 1.0 | 0.0 | 999.0 | 1.8 |
19 | 0.0 | 1.0 | 0.0 | 999.0 | 1.9 |
20 | 0.0 | 1.0 | 0.0 | 999.0 | 2.0 |
21 | 0.0 | 1.0 | 0.0 | 999.0 | 2.1 |
22 | 0.0 | 1.0 | 0.0 | 999.0 | 2.2 |
23 | 0.0 | 1.0 | 0.0 | 999.0 | 2.3 |
24 | 0.0 | 1.0 | 0.0 | 999.0 | 2.4 |
25 | 0.0 | 1.0 | 0.0 | 999.0 | 2.5 |
26 | 0.0 | 1.0 | 0.0 | 999.0 | 2.6 |
27 | 0.0 | 1.0 | 0.0 | 999.0 | 2.7 |
28 | 0.0 | 1.0 | 0.0 | 999.0 | 2.8 |
29 | 0.0 | 1.0 | 0.0 | 999.0 | 2.9 |
... | ... | ... | ... | ... | ... |
1971 | 0.0 | 1.0 | 0.0 | 999.0 | 197.1 |
1972 | 0.0 | 1.0 | 0.0 | 999.0 | 197.2 |
1973 | 0.0 | 1.0 | 0.0 | 999.0 | 197.3 |
1974 | 0.0 | 1.0 | 0.0 | 999.0 | 197.4 |
1975 | 0.0 | 1.0 | 0.0 | 999.0 | 197.5 |
1976 | 0.0 | 1.0 | 0.0 | 999.0 | 197.6 |
1977 | 0.0 | 1.0 | 0.0 | 999.0 | 197.7 |
1978 | 0.0 | 1.0 | 0.0 | 999.0 | 197.8 |
1979 | 0.0 | 1.0 | 0.0 | 999.0 | 197.9 |
1980 | 0.0 | 1.0 | 0.0 | 999.0 | 198.0 |
1981 | 0.0 | 1.0 | 0.0 | 999.0 | 198.1 |
1982 | 0.0 | 1.0 | 0.0 | 999.0 | 198.2 |
1983 | 0.0 | 1.0 | 0.0 | 999.0 | 198.3 |
1984 | 0.0 | 1.0 | 0.0 | 999.0 | 198.4 |
1985 | 0.0 | 1.0 | 0.0 | 999.0 | 198.5 |
1986 | 0.0 | 1.0 | 0.0 | 999.0 | 198.6 |
1987 | 0.0 | 1.0 | 0.0 | 999.0 | 198.7 |
1988 | 0.0 | 1.0 | 0.0 | 999.0 | 198.8 |
1989 | 0.0 | 1.0 | 0.0 | 999.0 | 198.9 |
1990 | 0.0 | 1.0 | 0.0 | 999.0 | 199.0 |
1991 | 0.0 | 1.0 | 0.0 | 999.0 | 199.1 |
1992 | 0.0 | 1.0 | 0.0 | 999.0 | 199.2 |
1993 | 0.0 | 1.0 | 0.0 | 999.0 | 199.3 |
1994 | 0.0 | 1.0 | 0.0 | 999.0 | 199.4 |
1995 | 0.0 | 1.0 | 0.0 | 999.0 | 199.5 |
1996 | 0.0 | 1.0 | 0.0 | 999.0 | 199.6 |
1997 | 0.0 | 1.0 | 0.0 | 999.0 | 199.7 |
1998 | 0.0 | 1.0 | 0.0 | 999.0 | 199.8 |
1999 | 0.0 | 1.0 | 0.0 | 999.0 | 199.9 |
2000 | 0.0 | 1.0 | 0.0 | 999.0 | 200.0 |
2001 rows × 5 columns
Visualisation
import matplotlib.pyplot as plt
plt.style.use("ggplot")

# One line per compartment, read by column name from sir_out. The empty
# string sits in plot's format-string position; colours are set explicitly.
sline, iline, rline = (
    plt.plot("t", column, "", data=sir_out, color=colour, linewidth=2)
    for column, colour in (("S", "red"), ("I", "green"), ("R", "blue"))
)

plt.xlabel("Time", fontweight="bold")
plt.ylabel("Number", fontweight="bold")

# Legend anchored just outside the right edge of the axes
# (loc="right" is the same location as the numeric code 5).
legend = plt.legend(title="Population", loc="right", bbox_to_anchor=(1.25, 0.5))
frame = legend.get_frame()
frame.set_facecolor("white")
frame.set_linewidth(0)