Forecast Combination
import os
os.chdir("../../")
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from scripts.python.tsa.mtsmodel import *
from scripts.python.tsa.ts_eval import *
import seaborn as sns
sns.set_style("whitegrid")
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
Timmermann (2004) summarizes the relative performance weights of Stock and Watson (1998) as follows. Let \(MSE_{t+h,t,i} = (1/v)\sum_{\tau=t-v}^{t} e^{2}_{\tau,\tau-h,i}\) be the \(i\)th forecasting model’s MSE at time \(t\), computed over a window of the previous \(v\) periods. Then
\[ \hat{y}_{t+h,t} = \sum_{i=1}^{N} \hat{\omega}_{t+h,t,i} \hat{y}_{t+h,t,i}, \text{ where } \hat{\omega}_{t+h,t,i} = \frac{(1/MSE_{t+h,t,i})}{\sum_{j=1}^{N} (1/MSE_{t+h,t,j})}\]
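The functions below calculate the relative performance weights for \(i \in \{\text{sarimax}, \text{lf}, \text{var}\}\). Note that the implementation uses an expanding window: the MSE at time \(t\) averages all squared errors observed up to and including \(t\), rather than only the last \(v\) periods.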
Show code cell source
def calculate_mse(predictions_df: pd.DataFrame, method: str) -> pd.Series:
    """Return the expanding-window MSE of one model's predictions.

    The entry at position ``k`` (0-based) is the mean of the squared errors
    ``(total - prediction) ** 2`` over rows ``0..k``.

    Parameters
    ----------
    predictions_df : pd.DataFrame
        Frame with a ``"total"`` column (the values predictions are compared
        against) and a column named ``method``.
    method : str
        Name of the prediction column to evaluate.

    Returns
    -------
    pd.Series
        Running MSE, indexed like ``predictions_df``.
    """
    squared_error = np.square(predictions_df["total"] - predictions_df[method])
    # Divide by the positional observation count (1..n). Using ``index + 1``
    # is only correct for a default 0-based RangeIndex and silently gives
    # wrong values (or fails) for filtered, re-labelled or date indexes.
    n_obs = np.arange(1, len(predictions_df) + 1)
    return squared_error.cumsum() / n_obs
def calculate_rpw(predictions_df: pd.DataFrame, methods: list) -> pd.Series:
    """Compute relative performance weights for each forecasting method.

    Each method's weight is its inverse MSE divided by the sum of the
    inverse MSEs of all ``methods``, so a method with a lower MSE receives
    a larger weight.

    Parameters
    ----------
    predictions_df : pd.DataFrame
        Frame passed through to ``calculate_mse`` for every method.
    methods : list
        Names of the prediction columns to weight.

    Returns
    -------
    pd.Series
        Keyed by method name; each value is that method's weights as
        returned by the element-wise division of the inverse MSEs.
    """
    # Inverse MSE per method, computed once and reused for both the
    # numerator and the normalising sum.
    inverse_mse = {
        name: 1 / calculate_mse(predictions_df, name) for name in methods
    }
    normaliser = sum(inverse_mse[name] for name in methods)
    return pd.Series(
        {name: inverse_mse[name] / normaliser for name in methods})
def get_rpw(pred_df: pd.DataFrame,
            methods: "list | None" = None
            ) -> "tuple[pd.Series, pd.DataFrame]":
    """Combine model predictions using relative performance weights.

    Parameters
    ----------
    pred_df : pd.DataFrame
        Frame holding one prediction column per entry in ``methods`` plus
        whatever ``calculate_rpw`` needs. It is copied, not modified.
    methods : list, optional
        Prediction column names to combine. Defaults to
        ``["sarimax", "var", "lf"]``.

    Returns
    -------
    rpw_pred : pd.Series
        Row-wise sum of each method's prediction times its weight.
    rpw : pd.DataFrame
        The weights themselves, one column per method, named
        ``"rpw_<method>"``.
    """
    # A ``None`` sentinel avoids sharing one mutable list across calls.
    if methods is None:
        methods = ["sarimax", "var", "lf"]

    predictions_df = pred_df.copy()
    rpw_series = calculate_rpw(predictions_df, methods)

    # Weighted contribution of each method, stored as "<method>_weight".
    combo_cols = []
    for method in methods:
        weighted_col = str(method) + "_weight"
        predictions_df[weighted_col] = (
            predictions_df[method] * rpw_series[method])
        combo_cols.append(weighted_col)
    rpw_pred = predictions_df[combo_cols].sum(axis=1)

    rpw = pd.DataFrame(rpw_series.to_dict())
    rpw.columns = ["rpw_" + str(col) for col in rpw.columns]
    return rpw_pred, rpw
def get_constrained_ls(y: pd.DataFrame,
                       X: pd.DataFrame
                       ) -> "tuple[np.ndarray, np.ndarray]":
    """Fit least-squares weights that are non-negative and sum to one.

    Minimises ``||X w - y||`` subject to ``w >= 0`` and ``sum(w) == 1``
    with SLSQP, starting from the non-negative least-squares solution.

    Parameters
    ----------
    y : pd.Series or single-column pd.DataFrame
        Target values, one per row of ``X``.
    X : pd.DataFrame
        Regressors, one column per weight to estimate.

    Returns
    -------
    estimates : np.ndarray
        The fitted weights, one per column of ``X``.
    pred : np.ndarray
        ``X`` multiplied by ``estimates``.
    """
    from scipy.optimize import nnls, minimize

    A = np.asarray(X, dtype=float)
    # Flatten so a single-column DataFrame behaves like a Series; a 2-D
    # target would otherwise broadcast against ``A.dot(x)`` in the residual.
    b = np.asarray(y, dtype=float).ravel()

    # Non-negative least squares provides the starting point.
    x0, _ = nnls(A, b)

    def residual_norm(x, A, b):
        return np.linalg.norm(A.dot(x) - b)

    # One explicit (lower, upper) pair per coefficient: lower bound 0, no
    # upper bound. This does not rely on a single pair being broadcast.
    bounds = [(0, None)] * A.shape[1]
    mini = minimize(residual_norm, x0, args=(A, b), method='SLSQP',
                    bounds=bounds,
                    constraints={'type': 'eq',
                                 'fun': lambda x: np.sum(x) - 1})
    estimates = mini.x
    pred = A.dot(estimates)
    return estimates, pred
# Component models, in the order used by every combination scheme below.
# Defined once so the ensemble mean, RPW, OLS, constrained LS and the
# evaluation table cannot drift apart.
methods = ["sarimax", "var", "lf"]

# (model name used as CSV file prefix, column in that CSV holding the
# model's prediction)
mappings = [("var", "pred_total"),
            ("sarimax", "train_pred"),
            ("lf", "pred_mean")]


def _highlight_min(column):
    """Return a yellow-background style for the smallest value(s) in a column."""
    lowest = column.min()
    return ['background-color: yellow' if v == lowest else ''
            for v in column]


for country in ["palau", "samoa", "tonga", "solomon_islands", "vanuatu"]:
    folderpath = os.path.join(
        os.getcwd(), "data", "tourism", str(country), "model")

    # Build one frame with date, total and one prediction column per model.
    country_pred = pd.DataFrame()
    for model, column in mappings:
        filepath = os.path.join(folderpath, f"{model}_{country}.csv")
        pred_df = pd.read_csv(filepath).drop("Unnamed: 0", axis=1)
        pred_df["date"] = pd.to_datetime(pred_df["date"])
        model_col = (pred_df[["date", "total", column]]
                     .rename({column: model}, axis=1))
        if country_pred.empty:
            country_pred = model_col
        else:
            # NOTE(review): rows missing from a later model are filled with
            # 0, i.e. treated as a forecast of zero -- confirm this is
            # intended rather than dropping or masking those rows.
            country_pred = country_pred.merge(
                model_col, how="left", on=["date", "total"]).fillna(0)

    # Mean
    country_pred["mean_ensemble"] = country_pred[methods].mean(axis=1)

    # Relative Performance Weights
    country_pred["rpw"], weights = get_rpw(country_pred, methods)

    # OLS (regularized) without intercept
    formula = "total~" + "+".join(methods) + " - 1"
    ols = smf.ols(formula, data=country_pred)
    ols_res = ols.fit()
    # NOTE(review): fit_regularized() is called with no arguments; the
    # statsmodels documentation lists alpha=0.0 as the default penalty
    # weight -- confirm whether a non-zero penalty was intended for the
    # column labelled "ols_lasso".
    ols_lasso = ols.fit_regularized()
    country_pred["ols"] = ols_res.fittedvalues
    country_pred["ols_lasso"] = ols_lasso.fittedvalues

    # Constrained least squares (non-negative weights summing to one)
    constrained_ls, cls_pred = get_constrained_ls(y=country_pred["total"],
                                                  X=country_pred[methods])
    country_pred["cls"] = cls_pred

    # Record the (constant) regression weights alongside the RPW weights.
    for method, cls in zip(methods, constrained_ls):
        weights["ols_" + str(method)] = ols_res.params[method]
        weights["ols_lasso_" + str(method)] = ols_lasso.params[method]
        weights["cls_" + str(method)] = cls
    weights["date"] = country_pred["date"]

    # Sort columns: move the just-added "date" column to the front.
    cols = weights.columns.tolist()
    weights = weights[cols[-1:] + cols[:-1]]

    # Save Combination weights
    weights.to_csv(os.path.join(folderpath, "forecast_combo_weights.csv"),
                   encoding="utf-8")

    # Save Forecast Prediction
    country_pred.to_csv(os.path.join(folderpath, "forecast_combo.csv"),
                        encoding="utf-8")

    # Evaluate every component model and every combination against "total".
    # One frame per column is collected and concatenated once, instead of
    # growing a DataFrame inside the loop.
    eval_cols = methods + ["mean_ensemble", "rpw", "ols", "cls", "ols_lasso"]
    evals = pd.concat(
        [pd.DataFrame(calculate_evaluation(country_pred["total"],
                                           country_pred[col]),
                      index=[col])
         for col in eval_cols],
        axis=0)
    evals.columns.name = str(country)

    # Highlight the minimum of each metric column in the rendered table.
    evals = evals.style.apply(_highlight_min)
    display(evals)
palau | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 1586348.581999 | 1259.503308 | 701.765452 | 53.857541 |
var | 1127006.129408 | 1061.605449 | 554.935555 | 38.243616 |
lf | 522059.184389 | 722.536632 | 403.159217 | 41.454946 |
mean_ensemble | 538679.103245 | 733.947616 | 415.191645 | 34.057742 |
rpw | 416042.585625 | 645.013632 | 355.022438 | 33.879718 |
ols | 459520.198387 | 677.879192 | 374.512959 | 34.509414 |
cls | 462504.925790 | 680.077147 | 377.897454 | 34.589740 |
ols_lasso | 485185.924697 | 696.552887 | 387.646868 | 33.639519 |
samoa | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 8076301.733757 | 2841.883483 | 1410.434675 | 141.993878 |
var | 10290887.267182 | 3207.941282 | 1757.557409 | 141.614714 |
lf | 2107650.783058 | 1451.775046 | 763.990597 | 131.108537 |
mean_ensemble | 3794368.839523 | 1947.913971 | 1093.282593 | 135.822954 |
rpw | 2242890.336299 | 1497.628237 | 746.332632 | 131.773854 |
ols | 1964126.186335 | 1401.472863 | 784.266056 | 131.370686 |
cls | 2009637.625508 | 1417.616882 | 763.104201 | 131.248177 |
ols_lasso | 2003163.146708 | 1415.331462 | 819.357748 | 131.860632 |
tonga | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 677589.950307 | 823.158521 | 473.356102 | 81.354328 |
var | 547847.120126 | 740.166954 | 354.330235 | 61.940427 |
lf | 203622.901860 | 451.245944 | 258.523902 | 87.345914 |
mean_ensemble | 194568.858596 | 441.099602 | 239.402169 | 78.543566 |
rpw | 155098.313391 | 393.825232 | 216.180497 | 83.251701 |
ols | 110277.230689 | 332.080157 | 198.679691 | 83.294941 |
cls | 167576.289342 | 409.360830 | 225.881663 | 84.660888 |
ols_lasso | 125878.110314 | 354.793053 | 197.297385 | 77.645737 |
solomon_islands | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 46105.079450 | 214.720934 | 147.806140 | 31.941491 |
var | 65021.669854 | 254.993470 | 154.338247 | 17.694166 |
lf | 48753.426027 | 220.801780 | 150.232442 | 26.614403 |
mean_ensemble | 31432.381706 | 177.291798 | 126.564826 | 22.784182 |
rpw | 26373.763079 | 162.400009 | 118.723610 | 22.831037 |
ols | 31219.515216 | 176.690450 | 128.227540 | 23.614795 |
cls | 31242.949442 | 176.756752 | 128.315830 | 23.459240 |
ols_lasso | 35808.420435 | 189.231130 | 135.363359 | 27.983984 |
vanuatu | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 669250.974064 | 818.077609 | 372.363525 | 133.369731 |
var | 1600258.286965 | 1265.013157 | 554.204122 | 134.493067 |
lf | 642113.393395 | 801.319782 | 497.354972 | 128.023752 |
mean_ensemble | 496754.494910 | 704.808126 | 351.729479 | 131.547897 |
rpw | 333825.379320 | 577.776236 | 289.982721 | 130.586670 |
ols | 394337.435274 | 627.962925 | 342.287171 | 131.085132 |
cls | 413561.135366 | 643.087191 | 354.292010 | 131.225479 |
ols_lasso | 419935.486543 | 648.024295 | 330.751929 | 131.571070 |