别再只用欧氏距离了!用Python手写DTW算法,轻松搞定语音识别和股票走势对比
当我们需要比较两段语音的相似度,或者分析两只股票的价格走势是否相关时,第一反应往往是计算它们的欧氏距离。但实际操作中你会发现,这种简单粗暴的方法经常"翻车"——明明两个波形模式相似,计算结果却显示差异巨大。这就是动态时间规整(DTW)算法要解决的核心问题。
想象这样一个场景:两位歌手演唱同一首歌,专业歌手可能会在某些段落放慢节奏,而业余爱好者可能在某些高音部分突然加快。如果简单按时间点一一对应比较,结果必然失真。DTW算法的精妙之处在于,它允许时间轴弹性伸缩,自动寻找最佳匹配路径,就像一位经验丰富的音乐制作人能够听出两版演唱中对应的音符段落。
1. 为什么欧氏距离在时间序列分析中经常失效
欧氏距离计算的是两个向量在空间中的直线距离,对于时间序列而言,它严格要求比较的两个序列必须长度相同,且每个时间点严格对应。这种刚性匹配在面对现实世界的时间序列数据时,往往会带来三个致命问题:
- 长度敏感:要求比较的两个序列必须等长
- 相位敏感:对时间轴上的微小偏移极度敏感
- 节奏敏感:无法处理局部加速/减速的情况
让我们用Python生成两个简单的序列来演示这个问题:
import numpy as np import matplotlib.pyplot as plt # 生成示例序列 t = np.linspace(0, 10, 100) seq1 = np.sin(t) seq2 = np.sin(t * 0.9) # 时间轴压缩10% # 计算欧氏距离 euclidean_dist = np.sqrt(np.sum((seq1 - seq2)**2)) print(f"欧氏距离: {euclidean_dist:.2f}")输出结果:欧氏距离: 6.42
从视觉上看这两个正弦波非常相似,但欧氏距离却给出了相当大的数值。这就是因为它强制要求t=5时刻的点必须与t=5时刻的点比较,而实际上由于时间轴压缩,seq1的t=5点可能更应该与seq2的t≈5.56点比较。
2. DTW算法原理与实现步骤
DTW算法的核心思想是通过动态规划寻找两个序列之间的最优匹配路径,允许时间轴的非线性扭曲。这种扭曲需要满足三个基本约束:
- 边界条件:匹配必须从两端开始和结束
- 单调性:时间不能倒流
- 连续性:不能跳过任何数据点
2.1 DTW的数学表达
给定两个序列X=(x₁,x₂,...,xₙ)和Y=(y₁,y₂,...,yₘ),DTW构建一个n×m的累积距离矩阵D,其中:
D[i,j] = d(x_i,y_j) + min(D[i-1,j], D[i-1,j-1], D[i,j-1])其中d(x_i,y_j)通常是欧氏距离|x_i - y_j|。最终的DTW距离是D[n,m]的值。
2.2 Python实现基础DTW
让我们从零实现一个基础的DTW算法:
def dtw_distance(s1, s2): n, m = len(s1), len(s2) dtw_matrix = np.zeros((n+1, m+1)) # 初始化边界条件 dtw_matrix[0, 0] = 0 for i in range(1, n+1): dtw_matrix[i, 0] = np.inf for j in range(1, m+1): dtw_matrix[0, j] = np.inf # 填充DTW矩阵 for i in range(1, n+1): for j in range(1, m+1): cost = abs(s1[i-1] - s2[j-1]) dtw_matrix[i, j] = cost + min( dtw_matrix[i-1, j], # 插入 dtw_matrix[i, j-1], # 删除 dtw_matrix[i-1, j-1] # 匹配 ) return dtw_matrix[n, m]测试我们之前的正弦波序列:
dtw_dist = dtw_distance(seq1, seq2) print(f"DTW距离: {dtw_dist:.2f}")输出结果:DTW距离: 1.56
与之前的欧氏距离6.42相比,DTW给出的距离值小得多,更符合我们视觉上的相似度判断。
3. DTW在语音识别中的实战应用
语音识别是DTW算法的经典应用场景。不同人说同一个单词时,发音速度、音调都会有差异,DTW能够有效对齐这些变化。
3.1 语音特征提取
通常我们不直接处理原始音频波形,而是提取MFCC(梅尔频率倒谱系数)作为特征:
import librosa # 加载两个语音样本 y1, sr1 = librosa.load('speech1.wav') y2, sr2 = librosa.load('speech2.wav') # 提取MFCC特征 mfcc1 = librosa.feature.mfcc(y=y1, sr=sr1, n_mfcc=13) mfcc2 = librosa.feature.mfcc(y=y2, sr=sr2, n_mfcc=13) # 计算DTW距离 dtw_dist = dtw_distance(mfcc1.T, mfcc2.T)3.2 可视化对齐路径
理解DTW如何对齐两个序列非常重要,我们可以绘制warping path:
def plot_dtw_path(s1, s2): dtw_matrix = compute_dtw_matrix(s1, s2) path = compute_optimal_path(dtw_matrix) plt.figure(figsize=(10, 6)) plt.imshow(dtw_matrix, cmap='gray_r', origin='lower') plt.plot(path[:,1], path[:,0], 'r') # 绘制最优路径 plt.colorbar() plt.title("DTW矩阵与最优路径") plt.show()4. 股票走势分析中的DTW应用
在金融领域,DTW可以用来比较不同股票的价格走势相似度,或者同一股票在不同时间段的表现模式。
4.1 数据预处理
股票价格需要先进行归一化处理,因为我们关心的是走势形态而非绝对价格:
def normalize_series(series): return (series - series.mean()) / series.std() # 获取两只股票的历史收盘价 stock_a = get_history('AAPL')['close'].values stock_b = get_history('MSFT')['close'].values # 归一化 norm_a = normalize_series(stock_a) norm_b = normalize_series(stock_b)4.2 多维度DTW比较
除了价格序列,我们还可以结合交易量、技术指标等多维度数据进行综合比较:
def multi_dim_dtw(series1, series2, weights=None): if weights is None: weights = np.ones(len(series1)) total_dist = 0 for i, (s1, s2) in enumerate(zip(series1, series2)): total_dist += weights[i] * dtw_distance(s1, s2) return total_dist / sum(weights) # 比较价格和交易量序列 price_dist = dtw_distance(norm_a, norm_b) volume_dist = dtw_distance(volume_a, volume_b) combined_dist = multi_dim_dtw([norm_a, volume_a], [norm_b, volume_b], [0.7, 0.3])5. DTW的优化与变种算法
基础DTW算法有几个常见问题需要优化:计算复杂度高、可能产生不自然的扭曲路径、对噪声敏感等。
5.1 加速技巧:窗口限制
通过设置warping window限制搜索范围,可以大幅降低计算量:
def dtw_with_window(s1, s2, window_size=10): n, m = len(s1), len(s2) window = max(window_size, abs(n-m)) dtw_matrix = np.full((n+1, m+1), np.inf) dtw_matrix[0, 0] = 0 for i in range(1, n+1): for j in range(max(1, i-window), min(m+1, i+window)): cost = abs(s1[i-1] - s2[j-1]) dtw_matrix[i, j] = cost + min( dtw_matrix[i-1, j], dtw_matrix[i, j-1], dtw_matrix[i-1, j-1] ) return dtw_matrix[n, m]5.2 导数DTW(DDTW)
通过比较序列的导数而非原始值,可以更好捕捉形状特征:
def derivative(series): deriv = np.zeros_like(series) deriv[1:-1] = (series[2:] - series[:-2]) / 2 deriv[0] = series[1] - series[0] deriv[-1] = series[-1] - series[-2] return deriv def ddtw_distance(s1, s2): deriv1 = derivative(s1) deriv2 = derivative(s2) return dtw_distance(deriv1, deriv2)6. 实际应用中的注意事项
虽然DTW功能强大,但在实际应用中需要注意以下几点:
- 计算复杂度:原始DTW是O(nm),对于长序列需要考虑分段或降采样
- 边界效应:序列开始和结束部分的匹配可能不可靠
- 参数选择:窗口大小、距离度量等需要根据具体场景调整
- 结果解释:DTW距离没有上界,不同应用需要自行定义阈值
# 实用的DTW类实现 class EnhancedDTW: def __init__(self, window_size=None, dist_metric='euclidean'): self.window_size = window_size self.dist_metric = dist_metric def _distance(self, a, b): if self.dist_metric == 'euclidean': return abs(a - b) elif self.dist_metric == 'sqeuclidean': return (a - b)**2 # 可扩展其他距离度量 def compute(self, s1, s2): n, m = len(s1), len(s2) if self.window_size is None: window = max(n, m) else: window = self.window_size dtw_matrix = np.full((n+1, m+1), np.inf) dtw_matrix[0, 0] = 0 for i in range(1, n+1): for j in range(max(1, i-window), min(m+1, i+window)): cost = self._distance(s1[i-1], s2[j-1]) dtw_matrix[i, j] = cost + min( dtw_matrix[i-1, j], dtw_matrix[i, j-1], dtw_matrix[i-1, j-1] ) return dtw_matrix[n, m]