1. MSGNet 是什么?
MSGNet 来自论文MSGNet: Learning Multi-Scale Inter-Series Correlations for Multivariate Time Series Forecasting,用于多变量时间序列预测。它关注的不只是单条序列内部的时间依赖,也关注多个变量之间的关系,并且强调:变量间关系会随时间尺度变化。例如电力负荷、天气、交通、金融资产之间的相关性,短周期和长周期下可能完全不同。论文明确指出,MSGNet 用频域分析提取显著周期,并用自适应图卷积捕捉不同时间尺度下的序列间关系。(arXiv)
官方仓库也把 MSGNet 的核心模块概括为三个部分:FFT 多尺度识别模块、用于序列间关系学习的自适应图卷积模块、用于序列内时间关系学习的多头注意力模块。(GitHub)
2. MSGNet 解决的问题
传统多变量预测模型常见有两类:
| 类型 | 建模重点 | 问题 |
|---|---|---|
| RNN / TCN / Transformer | 单变量或整体时间依赖 | 变量间关系建模弱 |
| GNN 时间序列模型 | 变量间图关系 | 常用固定图,忽略尺度变化 |
| TimesNet 类模型 | 多周期、多尺度时间模式 | 对变量间相关性的显式建模不足 |
MSGNet 的切入点是:
同一组变量,在不同时间尺度下可能对应不同的图结构。
比如:
小时尺度:温度和电力负荷可能强相关。
日尺度:工作日/周末模式更重要。
周尺度:节律、季节性、宏观因素更明显。
因此 MSGNet 不用一张固定图,而是学习多张图:
[
A_1, A_2, \dots, A_k
]
每张图对应一个尺度。
3. MSGNet 的整体流程
MSGNet 的一个核心单元通常称为ScaleGraph Block。论文描述每个 ScaleGraph Block 大致包含四步:识别输入序列尺度、用自适应图卷积学习该尺度下的序列间关系、用多头注意力捕捉尺度内时间关系、最后用 SoftMax 对不同尺度表示做自适应融合。(arXiv)
可以记成这条链:
输入序列 X ↓ FFT 找 Top-k 周期尺度 ↓ 按不同周期 reshape ↓ 每个尺度学习一张自适应邻接矩阵 ↓ MixHop / 图卷积捕捉变量间依赖 ↓ Multi-Head Attention 捕捉尺度内时间依赖 ↓ 按 FFT 幅值加权融合 ↓ 预测未来序列4. 关键模块详解
4.1 FFT 多尺度识别
MSGNet 用 FFT 来发现输入序列中的显著周期。论文中说明,FFT 用于检测主要周期性,并把这些周期作为时间尺度来源。(arXiv)
对于输入:
[
X \in \mathbb{R}^{B \times T \times N}
]
其中:
(B):batch size
(T):历史窗口长度
(N):变量个数
做 FFT:
[
\mathcal{F}(X) = FFT(X)
]
然后取频率幅值,选出 Top-k 个频率:
[
f_1, f_2, \dots, f_k = TopK(|FFT(X)|)
]
周期尺度为:
[
p_i = \frac{T}{f_i}
]
注意:代码里常写成整数除法:
period = x.shape[1] // top_frequencies官方实现里也有类似的FFT_for_Period函数,使用torch.fft.rfft找 Top-k 频率,并返回对应周期与尺度权重。(GitHub)
4.2 多尺度重排
假设得到周期 (p_i),就把序列 reshape 成:
[
[B, T, N, D] \rightarrow [B, \frac{T}{p_i}, p_i, N, D]
]
含义是:
(\frac{T}{p_i}):有多少个周期片段
(p_i):每个周期内部的时间步
(N):变量节点
(D):隐藏维度
这个 reshape 很像把一卷时间胶片剪成若干段周期小胶片,模型可以分别观察“周期内”和“周期间”的纹理。
4.3 每个尺度学习一张自适应图
MSGNet 对每个尺度学习一个邻接矩阵。论文中描述为用两个可训练参数相乘,再经过 ReLU 和 SoftMax 得到自适应邻接矩阵。(arXiv)
常见写法:
[
A_i = SoftMax(ReLU(E_1^i E_2^i))
]
其中:
(A_i \in \mathbb{R}^{N \times N})
(N):变量数
(i):第 (i) 个尺度
(E_1^i, E_2^i):可学习节点嵌入参数
直观理解:
每个变量 = 图上的一个节点 变量之间的依赖 = 图上的边 每个时间尺度 = 一张不同的图4.4 MixHop 图卷积捕捉变量间关系
MSGNet 使用 MixHop 图卷积来捕捉多跳变量依赖。论文中提到,得到某个尺度下的邻接矩阵后,使用 MixHop graph convolution 捕捉序列间相关性。(arXiv)
普通一阶图卷积只看直接邻居:
[
A X
]
MixHop 会看多跳邻居:
[
X, AX, A^2X, \dots, A^PX
]
然后拼接或融合:
[
H_{out} = MLP([X, AX, A^2X, \dots, A^PX])
]
这样可以捕捉:
一阶关系:A 直接影响 B
二阶关系:A 通过 B 影响 C
多跳关系:复杂变量网络中的间接传导
4.5 多头注意力捕捉尺度内时间依赖
图卷积负责“变量之间说了什么”,多头注意力负责“这个周期内部的时间节奏怎么跳”。
论文中说明,MSGNet 在每个时间尺度下使用 MHA 捕捉 intra-series correlations,即序列内部时间关系。由于尺度变换把长时间跨度折叠成周期长度,注意力可以更容易处理局部周期结构。(arXiv)
4.6 多尺度融合
最后,MSGNet 根据 FFT 得到的幅值,对不同尺度的输出做 SoftMax 加权融合。论文把这种机制解释为一种 Mixture-of-Experts 思路:不同尺度像不同专家,FFT 幅值越大,说明这个尺度越重要。(arXiv)
[
Y = \sum_{i=1}^{k} \alpha_i Y_i
]
其中:
[
\alpha = SoftMax(w)
]
5. PyTorch 教学版 MSGNet
下面代码是简化教学版,重点复现 MSGNet 的核心思想:
FFT 识别周期
每个尺度独立图卷积
MixHop 多跳传播
周期内 Multi-Head Attention
多尺度加权融合
输出未来预测
它不是官方代码逐行复刻,而是更适合学习、讲解、二次开发的版本。
5.1 完整模型代码
import torch import torch.nn as nn import torch.nn.functional as F def fft_for_period(x, top_k=3): """ 使用 FFT 找出输入序列的 Top-k 周期。 参数: x: [B, T, N, D] B = batch size T = 输入长度 N = 变量数 D = 隐藏维度 top_k: 选取前 k 个主要频率 返回: periods: List[int], 周期长度 weights: [B, top_k], 每个样本在 Top-k 频率上的幅值 """ B, T, N, D = x.shape # 沿时间维做实数 FFT xf = torch.fft.rfft(x, dim=1) # [B, F, N, D] # 全局平均幅值,用于选频率 freq_amplitude = xf.abs().mean(dim=(0, 2, 3)) # [F] # 去掉直流分量,频率 0 表示整体均值,不代表周期 freq_amplitude[0] = 0 _, top_indices = torch.topk(freq_amplitude, top_k) # 频率索引 f 对应周期 T // f periods = (T // top_indices).detach().cpu().tolist() periods = [max(1, int(p)) for p in periods] # 每个样本的尺度权重 sample_weights = xf.abs().mean(dim=(2, 3))[:, top_indices] # [B, top_k] return periods, sample_weights class AdaptiveMixHopGraphConv(nn.Module): """ 自适应 MixHop 图卷积。 每个尺度都有自己的 E1, E2,用来学习邻接矩阵 A。 然后进行多跳传播: X, AX, A^2X, ... """ def __init__(self, num_nodes, d_model, node_dim=16, gcn_depth=2, dropout=0.1, alpha=0.2): super().__init__() self.num_nodes = num_nodes self.d_model = d_model self.gcn_depth = gcn_depth self.alpha = alpha self.dropout = nn.Dropout(dropout) # 自适应图参数 self.nodevec1 = nn.Parameter(torch.randn(num_nodes, node_dim)) self.nodevec2 = nn.Parameter(torch.randn(node_dim, num_nodes)) # MixHop 后把多跳特征融合回 d_model self.proj = nn.Linear((gcn_depth + 1) * d_model, d_model) self.norm = nn.LayerNorm(d_model) def learned_adj(self): """ 学习邻接矩阵 A: [N, N] """ adj = torch.mm(self.nodevec1, self.nodevec2) adj = F.relu(adj) adj = F.softmax(adj, dim=-1) return adj def graph_propagate(self, x, adj): """ 图传播。 参数: x: [B, T, N, D] adj: [N, N] 返回: out: [B, T, N, D] """ # adj[i, j] 表示节点 i 从节点 j 聚合信息 return torch.einsum("ij,btjd->btid", adj, x) def forward(self, x): """ 参数: x: [B, T, N, D] 返回: out: [B, T, N, D] """ adj = self.learned_adj() # 加自环并归一化 eye = torch.eye(self.num_nodes, device=x.device) adj = adj + eye adj = adj / adj.sum(dim=-1, keepdim=True).clamp(min=1e-6) h = x outputs = [h] for _ in range(self.gcn_depth): # APPNP 风格的残差传播 h = self.alpha * x + (1 - self.alpha) * self.graph_propagate(h, adj) outputs.append(h) # 拼接多跳特征 out = torch.cat(outputs, dim=-1) # [B, T, N, (gcn_depth+1)*D] out = self.proj(out) out = self.dropout(out) return self.norm(x + out) class ScaleGraphBlock(nn.Module): """ 一个 MSGNet 风格的 ScaleGraph Block。 步骤: 1. FFT 找 Top-k 周期 2. 每个尺度做自适应 MixHop 图卷积 3. reshape 后做周期内 Multi-Head Attention 4. 多尺度加权融合 """ def __init__( self, seq_len, num_nodes, d_model, top_k=3, node_dim=16, gcn_depth=2, n_heads=4, dropout=0.1 ): super().__init__() self.seq_len = seq_len self.num_nodes = num_nodes self.d_model = d_model self.top_k = top_k self.graph_convs = nn.ModuleList([ AdaptiveMixHopGraphConv( num_nodes=num_nodes, d_model=d_model, node_dim=node_dim, gcn_depth=gcn_depth, dropout=dropout ) for _ in range(top_k) ]) self.attn = nn.MultiheadAttention( embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True ) self.ffn = nn.Sequential( nn.LayerNorm(d_model), nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * d_model, d_model), nn.Dropout(dropout) ) self.norm = nn.LayerNorm(d_model) def temporal_attention_by_period(self, x, period): """ 在某个周期尺度下做时间注意力。 参数: x: [B, T, N, D] period: 周期长度 p 返回: out: [B, T, N, D] """ B, T, N, D = x.shape # padding 到 period 的整数倍 if T % period != 0: new_T = ((T // period) + 1) * period pad_len = new_T - T padding = torch.zeros(B, pad_len, N, D, device=x.device, dtype=x.dtype) x_pad = torch.cat([x, padding], dim=1) else: new_T = T x_pad = x num_segments = new_T // period # [B, new_T, N, D] -> [B, S, p, N, D] x_period = x_pad.reshape(B, num_segments, period, N, D) # 对每个变量、每个周期片段,在 period 维度上做 attention # [B, S, p, N, D] -> [B*S*N, p, D] x_attn = x_period.permute(0, 1, 3, 2, 4).reshape(B * num_segments * N, period, D) attn_out, _ = self.attn(x_attn, x_attn, x_attn) x_attn = x_attn + attn_out x_attn = x_attn + self.ffn(x_attn) # reshape 回 [B, new_T, N, D] x_back = x_attn.reshape(B, num_segments, N, period, D) x_back = x_back.permute(0, 1, 3, 2, 4).reshape(B, new_T, N, D) # 去掉 padding return x_back[:, :T] def forward(self, x): """ 参数: x: [B, T, N, D] 返回: out: [B, T, N, D] """ B, T, N, D = x.shape periods, scale_weights = fft_for_period(x, self.top_k) scale_outputs = [] for i, period in enumerate(periods): # 1. 该尺度下的图卷积,捕捉变量间关系 out = self.graph_convs[i](x) # 2. 该尺度下的周期内注意力,捕捉时间关系 out = self.temporal_attention_by_period(out, period) scale_outputs.append(out) # [B, T, N, D, K] stacked = torch.stack(scale_outputs, dim=-1) # [B, K] -> [B, 1, 1, 1, K] scale_weights = F.softmax(scale_weights, dim=-1) scale_weights = scale_weights[:, None, None, None, :] # 多尺度融合 out = (stacked * scale_weights).sum(dim=-1) # 残差连接 out = self.norm(x + out) return out class MiniMSGNet(nn.Module): """ 教学版 MSGNet,用于多变量时间序列预测。 输入: x: [B, seq_len, num_nodes] 输出: y_hat: [B, pred_len, num_nodes] """ def __init__( self, seq_len, pred_len, num_nodes, d_model=64, e_layers=2, top_k=3, node_dim=16, gcn_depth=2, n_heads=4, dropout=0.1 ): super().__init__() self.seq_len = seq_len self.pred_len = pred_len self.num_nodes = num_nodes self.d_model = d_model # 每个变量的标量值映射到隐藏空间 self.value_embedding = nn.Linear(1, d_model) self.blocks = nn.ModuleList([ ScaleGraphBlock( seq_len=seq_len, num_nodes=num_nodes, d_model=d_model, top_k=top_k, node_dim=node_dim, gcn_depth=gcn_depth, n_heads=n_heads, dropout=dropout ) for _ in range(e_layers) ]) # 隐藏维度映射回单变量值 self.output_projection = nn.Linear(d_model, 1) # 沿时间维从 seq_len 预测 pred_len self.temporal_projection = nn.Linear(seq_len, pred_len) def forward(self, x): """ 参数: x: [B, T, N] 返回: y_hat: [B, pred_len, N] """ # Non-stationary Transformer 风格标准化 mean = x.mean(dim=1, keepdim=True).detach() std = x.std(dim=1, keepdim=True, unbiased=False).detach() x_norm = (x - mean) / (std + 1e-5) # [B, T, N] -> [B, T, N, 1] -> [B, T, N, D] h = self.value_embedding(x_norm.unsqueeze(-1)) for block in self.blocks: h = block(h) # [B, T, N, D] -> [B, T, N] h = self.output_projection(h).squeeze(-1) # [B, T, N] -> [B, N, T] -> [B, N, pred_len] y = h.permute(0, 2, 1) y = self.temporal_projection(y) # [B, N, pred_len] -> [B, pred_len, N] y = y.permute(0, 2, 1) # 反标准化 y = y * (std[:, 0, :].unsqueeze(1) + 1e-5) + mean[:, 0, :].unsqueeze(1) return y6. 最小训练示例
下面构造一个合成多变量时间序列,里面包含不同周期和变量依赖关系,然后训练上面的MiniMSGNet。
import math import torch from torch.utils.data import Dataset, DataLoader class SlidingWindowDataset(Dataset): def __init__(self, data, seq_len, pred_len): """ data: [T_total, N] """ self.data = torch.tensor(data, dtype=torch.float32) self.seq_len = seq_len self.pred_len = pred_len def __len__(self): return len(self.data) - self.seq_len - self.pred_len + 1 def __getitem__(self, idx): x = self.data[idx: idx + self.seq_len] y = self.data[idx + self.seq_len: idx + self.seq_len + self.pred_len] return x, y def make_synthetic_data(total_len=2000, num_nodes=5): """ 构造一个多变量时间序列: 节点 0: 日周期 + 噪声 节点 1: 依赖节点 0,带滞后 节点 2: 长周期 节点 3: 节点 0 和节点 2 的混合 节点 4: 较强噪声变量 """ t = torch.arange(total_len).float() x0 = torch.sin(2 * math.pi * t / 24) x1 = 0.7 * torch.roll(x0, shifts=3) + 0.2 * torch.sin(2 * math.pi * t / 12) x2 = torch.sin(2 * math.pi * t / 96) x3 = 0.5 * x0 + 0.5 * x2 x4 = 0.3 * torch.sin(2 * math.pi * t / 48) data = torch.stack([x0, x1, x2, x3, x4], dim=-1) data += 0.05 * torch.randn_like(data) return data.numpy() # 参数 seq_len = 96 pred_len = 24 num_nodes = 5 data = make_synthetic_data(total_len=2500, num_nodes=num_nodes) train_data = data[:2000] test_data = data[2000 - seq_len - pred_len:] train_dataset = SlidingWindowDataset(train_data, seq_len, pred_len) test_dataset = SlidingWindowDataset(test_data, seq_len, pred_len) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = MiniMSGNet( seq_len=seq_len, pred_len=pred_len, num_nodes=num_nodes, d_model=64, e_layers=2, top_k=3, node_dim=16, gcn_depth=2, n_heads=4, dropout=0.1 ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = torch.nn.MSELoss() # 训练 for epoch in range(5): model.train() total_loss = 0.0 for x, y in train_loader: x = x.to(device) # [B, seq_len, N] y = y.to(device) # [B, pred_len, N] optimizer.zero_grad() pred = model(x) loss = criterion(pred, y) loss.backward() optimizer.step() total_loss += loss.item() avg_loss = total_loss / len(train_loader) # 测试 model.eval() test_loss = 0.0 with torch.no_grad(): for x, y in test_loader: x = x.to(device) y = y.to(device) pred = model(x) loss = criterion(pred, y) test_loss += loss.item() test_loss /= len(test_loader) print(f"Epoch {epoch + 1:02d} | Train Loss: {avg_loss:.6f} | Test Loss: {test_loss:.6f}")7. 单次预测示例
model.eval() x, y_true = test_dataset[0] x = x.unsqueeze(0).to(device) # [1, seq_len, N] with torch.no_grad(): y_pred = model(x) print("预测形状:", y_pred.shape) print("真实形状:", y_true.shape) # y_pred: [1, pred_len, num_nodes]8. 如何查看模型学到的图结构?
因为每个尺度都有独立的AdaptiveMixHopGraphConv,所以可以把每个尺度的邻接矩阵取出来。
block = model.blocks[0] for i, gconv in enumerate(block.graph_convs): adj = gconv.learned_adj().detach().cpu() print(f"Scale {i} adjacency matrix:") print(adj)输出类似:
Scale 0 adjacency matrix: tensor([[0.21, 0.35, 0.10, 0.25, 0.09], [0.31, 0.20, 0.08, 0.28, 0.13], ...])这可以解释为:
第 0 个尺度下,变量之间的依赖图。
第 1 个尺度下,变量关系可能不同。
第 2 个尺度下,模型可能学到另一个周期模式。
这正是 MSGNet 的精髓:不同周期,不同图谱。
9. 与官方 MSGNet 的关系
上面的代码是“教学简化版”。官方实现中,ScaleGraphBlock调用FFT_for_Period获取周期与尺度权重,然后对每个尺度使用GraphBlock,再 reshape 做注意力,最后按尺度权重融合。(GitHub)
官方MSGBlock.py里可以看到:
GraphBlock使用nodevec1和nodevec2学习自适应邻接矩阵。mixprop执行多跳图传播。Attention_Block实现多头自注意力结构。(GitHub)
官方仓库 README 也说明可以通过脚本训练评估 MSGNet,例如sh ./scripts/ETTh1.sh,并支持把新模型文件加入./models目录进行扩展。(GitHub)
10. 适合论文写作的描述
可以这样写:
MSGNet 通过频域分析自动识别多变量时间序列中的显著周期模式,并将输入序列映射到多个时间尺度空间。在每个尺度下,模型构建独立的自适应图结构,以刻画该尺度中特定的变量间依赖关系;随后利用 MixHop 图卷积捕捉多跳序列间相关性,并结合多头自注意力机制建模尺度内时间依赖。最后,模型根据 FFT 幅值对不同尺度表示进行自适应加权融合,从而实现多尺度序列内与序列间关系的联合建模。
11. 实战调参建议
| 参数 | 含义 | 建议 |
|---|---|---|
seq_len | 输入历史窗口 | 常用 96、192、336 |
pred_len | 预测长度 | 24、48、96、192 等 |
top_k | 选几个主要周期 | 3 或 5 常见 |
gcn_depth | MixHop 传播深度 | 2 通常够用 |
node_dim | 学习邻接矩阵的节点嵌入维度 | 8、16、32 |
d_model | 隐藏维度 | 32、64、128 |
e_layers | ScaleGraph Block 层数 | 1 到 3 |
n_heads | 注意力头数 | 4 或 8 |
一个稳妥起步配置:
model = MiniMSGNet( seq_len=96, pred_len=24, num_nodes=变量数, d_model=64, e_layers=2, top_k=3, node_dim=16, gcn_depth=2, n_heads=4, dropout=0.1 )12. 一句话总结
MSGNet 的核心是:用 FFT 找多尺度周期,用每个尺度专属的自适应图学习变量间关系,再用注意力捕捉周期内时间依赖,最后把不同尺度的信息融合起来做预测。