背景知识:模型漂移与再训练策略
金融市场是非平稳的。今天有效的模型,明天可能失效。检测漂移、及时再训练,是生产级系统的必备能力。
一、什么是模型漂移?
模型漂移(Model Drift)指模型在部署后,预测性能随时间逐渐下降的现象。
1.1 漂移的两种类型
| 类型 | 定义 | 金融场景示例 |
|---|---|---|
| 数据漂移(Data Drift) | 输入特征的分布发生变化 | 波动率从 15% 升至 40%(COVID危机) |
| 概念漂移(Concept Drift) | 特征与目标的关系发生变化 | 动量因子从有效变为无效(regime切换) |
1.2 金融市场漂移的根本原因
为什么金融模型必然漂移?
1. 市场参与者结构变化
- 散户涌入 → 动量效应增强
- 量化基金增多 → Alpha衰减
2. 宏观环境变化
- 利率周期切换(QE → 加息)
- 经济周期切换(扩张 → 衰退)
3. 监管政策变化
- 做空限制 → 价格发现机制改变
- 高频交易限制 → 微观结构变化
4. 技术与信息变化
- 新数据源出现 → 旧因子被抢跑
- AI普及 → 策略同质化
二、漂移检测方法
2.1 性能监控法
最直接的方法:监控滚动窗口的策略表现。
import numpy as np
class PerformanceMonitor:
"""性能漂移监控器"""
def __init__(self, window: int = 30, sharpe_threshold: float = 0.5):
self.window = window # 滚动窗口(天)
self.sharpe_threshold = sharpe_threshold
self.returns = []
def update(self, daily_return: float) -> dict:
"""更新并检查漂移"""
self.returns.append(daily_return)
if len(self.returns) < self.window:
return {'status': 'warming_up'}
# 计算滚动夏普
recent = self.returns[-self.window:]
rolling_sharpe = np.mean(recent) / np.std(recent) * np.sqrt(252)
# 检测漂移
is_drifting = rolling_sharpe < self.sharpe_threshold
return {
'rolling_sharpe': rolling_sharpe,
'is_drifting': is_drifting,
'alert': 'DRIFT_DETECTED' if is_drifting else 'OK'
}
阈值设置建议:
| 指标 | 警告阈值 | 严重阈值 | 触发动作 |
|---|---|---|---|
| 滚动夏普 | < 0.5 | < 0 | 触发再训练 |
| 滚动胜率 | < 45% | < 40% | 检查信号质量 |
| 滚动收益 | < -5% | < -10% | 降低仓位 |
2.2 统计检验法
Kolmogorov-Smirnov 检验(K-S检验)
检测特征分布是否发生显著变化。
import numpy as np
from scipy.stats import ks_2samp
def detect_data_drift(
training_data: np.ndarray,
recent_data: np.ndarray,
significance: float = 0.05
) -> dict:
"""
K-S检验检测数据漂移
原理:比较两个样本是否来自同一分布
H0: 两个样本来自同一分布
如果 p < significance,拒绝H0,认为发生漂移
"""
statistic, p_value = ks_2samp(training_data, recent_data)
return {
'ks_statistic': statistic, # D值,越大表示分布差异越大
'p_value': p_value,
'is_drifting': p_value < significance,
'interpretation': 'DRIFT' if p_value < significance else 'STABLE'
}
# 使用示例
training_returns = returns['2020-01':'2022-12']
recent_returns = returns['2024-01':'2024-03']
result = detect_data_drift(training_returns, recent_returns)
print(f"K-S统计量: {result['ks_statistic']:.4f}")
print(f"P值: {result['p_value']:.4f}")
print(f"状态: {result['interpretation']}")
卡方检验(Chi-Square Test)
适用于分类特征的漂移检测。
from scipy.stats import chi2_contingency
def detect_categorical_drift(
training_counts: dict,
recent_counts: dict,
significance: float = 0.05
) -> dict:
"""
卡方检验检测分类特征漂移
示例:检测市场状态标签分布是否变化
training_counts = {'bull': 120, 'bear': 80, 'sideways': 50}
recent_counts = {'bull': 10, 'bear': 35, 'sideways': 5}
"""
# 构建列联表
categories = set(training_counts.keys()) | set(recent_counts.keys())
train_freq = [training_counts.get(c, 0) for c in categories]
recent_freq = [recent_counts.get(c, 0) for c in categories]
contingency_table = [train_freq, recent_freq]
chi2, p_value, dof, expected = chi2_contingency(contingency_table)
return {
'chi2_statistic': chi2,
'p_value': p_value,
'degrees_of_freedom': dof,
'is_drifting': p_value < significance
}
2.3 CUSUM 控制图法
累积和控制图:检测预测误差的持续偏移。
class CUSUMDetector:
"""
CUSUM(累积和)漂移检测器
原理:
- 累积预测误差的偏离
- 如果误差随机,累积和应在0附近波动
- 如果存在系统性偏差,累积和会持续偏离
"""
def __init__(self, threshold: float = 5.0, drift: float = 0.5):
"""
参数:
- threshold: 触发告警的阈值
- drift: 允许的漂移量(敏感度控制)
"""
self.threshold = threshold
self.drift = drift
self.reset()
def reset(self):
self.s_pos = 0 # 正向累积和
self.s_neg = 0 # 负向累积和
self.history = []
def update(self, error: float) -> dict:
"""
更新CUSUM值
参数:
- error: 预测误差(预测值 - 实际值)
返回:
- 漂移检测结果
"""
# 标准化误差
normalized_error = error
# 更新累积和
self.s_pos = max(0, self.s_pos + normalized_error - self.drift)
self.s_neg = max(0, self.s_neg - normalized_error - self.drift)
self.history.append({
's_pos': self.s_pos,
's_neg': self.s_neg,
'error': error
})
# 检测漂移
drift_up = self.s_pos > self.threshold
drift_down = self.s_neg > self.threshold
if drift_up or drift_down:
direction = 'UP' if drift_up else 'DOWN'
return {
'is_drifting': True,
'direction': direction,
'cusum_value': self.s_pos if drift_up else self.s_neg,
'action': 'RETRAIN_RECOMMENDED'
}
return {
'is_drifting': False,
'cusum_pos': self.s_pos,
'cusum_neg': self.s_neg,
'action': 'CONTINUE_MONITORING'
}
# 使用示例
detector = CUSUMDetector(threshold=5.0, drift=0.5)
for pred, actual in zip(predictions, actuals):
error = pred - actual
result = detector.update(error)
if result['is_drifting']:
print(f"检测到漂移!方向: {result['direction']}")
break
CUSUM的优势:
- 能检测渐进的、微小的持续偏移
- 比单点检测更敏感
- 有明确的统计学基础
2.4 多指标综合检测
生产级推荐:组合多种检测方法,降低误报率。
class ComprehensiveDriftDetector:
"""综合漂移检测器"""
def __init__(self):
self.performance_monitor = PerformanceMonitor()
self.cusum_detector = CUSUMDetector()
def check_drift(self,
daily_return: float,
prediction_error: float,
training_features: np.array,
recent_features: np.array) -> dict:
results = {}
# 1. 性能监控
perf_result = self.performance_monitor.update(daily_return)
results['performance'] = perf_result
# 2. CUSUM检测
cusum_result = self.cusum_detector.update(prediction_error)
results['cusum'] = cusum_result
# 3. K-S检验(定期执行,如每周)
ks_result = detect_data_drift(training_features, recent_features)
results['ks_test'] = ks_result
# 综合判断:多数投票
drift_signals = [
perf_result.get('is_drifting', False),
cusum_result.get('is_drifting', False),
ks_result.get('is_drifting', False)
]
drift_count = sum(drift_signals)
results['overall'] = {
'drift_count': drift_count,
'is_drifting': drift_count >= 2, # 至少2个检测器报警
'confidence': drift_count / 3,
'recommendation': self._get_recommendation(drift_count)
}
return results
def _get_recommendation(self, drift_count: int) -> str:
if drift_count == 0:
return 'CONTINUE_NORMAL'
elif drift_count == 1:
return 'INCREASE_MONITORING'
elif drift_count == 2:
return 'PREPARE_RETRAIN'
else:
return 'IMMEDIATE_RETRAIN'
三、再训练策略
3.1 定期再训练
最简单的策略:按固定周期重新训练模型。
| 策略频率 | 周期 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|---|
| 日频策略 | 月度 | 中低频因子策略 | 简单可预测 | 可能滞后 |
| 周频策略 | 季度 | 组合配置策略 | 成本低 | 不适应突变 |
| 分钟级策略 | 周度 | 高频交易 | 及时更新 | 成本高 |
# 定期再训练调度器
class ScheduledRetrainer:
def __init__(self, retrain_frequency: str = 'monthly'):
self.frequency = retrain_frequency
self.last_retrain = None
def should_retrain(self, current_date) -> bool:
if self.last_retrain is None:
return True
if self.frequency == 'weekly':
return (current_date - self.last_retrain).days >= 7
elif self.frequency == 'monthly':
return (current_date - self.last_retrain).days >= 30
elif self.frequency == 'quarterly':
return (current_date - self.last_retrain).days >= 90
return False
3.2 触发式再训练
更智能的策略:检测到漂移时才触发再训练。
class TriggeredRetrainer:
"""触发式再训练器"""
def __init__(self,
performance_threshold: float = 0.3, # 夏普阈值
cusum_threshold: float = 5.0,
min_interval_days: int = 7): # 最小再训练间隔
self.performance_threshold = performance_threshold
self.cusum_threshold = cusum_threshold
self.min_interval_days = min_interval_days
self.last_retrain = None
self.detector = ComprehensiveDriftDetector()
def check_and_retrain(self, model, new_data, current_date) -> dict:
"""检查是否需要再训练,如需要则执行"""
# 防止过于频繁的再训练
if self.last_retrain:
days_since = (current_date - self.last_retrain).days
if days_since < self.min_interval_days:
return {'action': 'SKIP', 'reason': 'Too soon since last retrain'}
# 漂移检测
drift_result = self.detector.check_drift(...)
if drift_result['overall']['is_drifting']:
# 执行再训练
new_model = self._retrain(model, new_data)
self.last_retrain = current_date
return {
'action': 'RETRAINED',
'drift_confidence': drift_result['overall']['confidence'],
'new_model': new_model
}
return {'action': 'CONTINUE', 'drift_confidence': drift_result['overall']['confidence']}
3.3 在线学习
持续更新:不完全重训练,而是增量更新模型参数。
class OnlineLearner:
"""
在线学习更新器
适用场景:
- 需要快速适应市场变化
- 完全重训练成本过高
- 数据流持续到达
风险:
- 灾难性遗忘(忘记历史模式)
- 对噪音敏感
"""
def __init__(self, model, learning_rate: float = 0.001):
self.model = model
self.learning_rate = learning_rate
self.update_count = 0
def incremental_update(self, new_x, new_y):
"""
增量更新模型
使用较小的学习率进行单步梯度下降
"""
# 前向传播
prediction = self.model.predict(new_x)
error = new_y - prediction
# 反向传播(简化示意)
gradient = self._compute_gradient(new_x, error)
# 参数更新
for param, grad in zip(self.model.parameters(), gradient):
param -= self.learning_rate * grad
self.update_count += 1
return {
'prediction': prediction,
'error': error,
'update_count': self.update_count
}
def _compute_gradient(self, x, error):
# 实际实现取决于模型类型
pass
在线学习的陷阱:
- 灾难性遗忘:新数据覆盖旧知识
- 噪音累积:单样本更新容易被噪音误导
- 学习率敏感:太大→不稳定,太小→适应慢
3.4 混合策略(推荐)
最佳实践:结合定期和触发式再训练。
class HybridRetrainer:
"""混合再训练策略"""
def __init__(self):
self.scheduled_interval_days = 30 # 定期:每月
self.drift_detector = ComprehensiveDriftDetector()
self.last_scheduled_retrain = None
self.last_triggered_retrain = None
def should_retrain(self, current_date, metrics) -> dict:
"""判断是否需要再训练"""
# 检查定期再训练
scheduled_due = self._check_scheduled(current_date)
# 检查触发式再训练
drift_result = self.drift_detector.check_drift(metrics)
triggered_due = drift_result['overall']['is_drifting']
if scheduled_due and triggered_due:
return {
'should_retrain': True,
'reason': 'BOTH_SCHEDULED_AND_DRIFT',
'priority': 'HIGH'
}
elif triggered_due:
return {
'should_retrain': True,
'reason': 'DRIFT_DETECTED',
'priority': 'HIGH'
}
elif scheduled_due:
return {
'should_retrain': True,
'reason': 'SCHEDULED',
'priority': 'NORMAL'
}
return {'should_retrain': False, 'reason': 'NO_TRIGGER'}
四、再训练的最佳实践
4.1 训练数据选择
| 策略 | 描述 | 优点 | 缺点 |
|---|---|---|---|
| 扩展窗口 | 使用所有历史数据 | 样本量大 | 旧数据可能过时 |
| 滑动窗口 | 只用最近N天数据 | 适应新模式 | 可能丢失重要历史 |
| 加权窗口 | 近期数据权重高 | 平衡历史与当前 | 权重选择困难 |
推荐:滑动窗口 + 危机期数据保留
def prepare_training_data(all_data, window_days=252*2, keep_crisis=True):
"""准备再训练数据"""
# 滑动窗口
recent_data = all_data.iloc[-window_days:]
if keep_crisis:
# 保留重要危机期数据
crisis_periods = [
('2008-09', '2009-03'), # 金融危机
('2020-02', '2020-04'), # COVID
('2022-01', '2022-06'), # 加息冲击
]
crisis_data = []
for start, end in crisis_periods:
if start in all_data.index:
crisis_data.append(all_data.loc[start:end])
# 合并
training_data = pd.concat([recent_data] + crisis_data)
training_data = training_data.drop_duplicates()
return training_data
4.2 模型版本管理
# 模型版本管理
class ModelVersionManager:
def __init__(self, storage_path: str):
self.storage_path = storage_path
self.versions = []
def save_version(self, model, metrics: dict, reason: str):
"""保存模型版本"""
version_id = f"v{len(self.versions)+1}_{datetime.now():%Y%m%d_%H%M}"
version_info = {
'version_id': version_id,
'timestamp': datetime.now(),
'reason': reason,
'metrics': metrics,
'model_path': f"{self.storage_path}/{version_id}.pkl"
}
# 保存模型
joblib.dump(model, version_info['model_path'])
self.versions.append(version_info)
return version_id
def rollback(self, version_id: str):
"""回滚到指定版本"""
for v in self.versions:
if v['version_id'] == version_id:
return joblib.load(v['model_path'])
raise ValueError(f"Version {version_id} not found")
4.3 A/B测试
再训练后,不要直接替换旧模型,而是进行对比测试。
class ABTester:
"""模型A/B测试"""
def __init__(self, old_model, new_model, test_days: int = 5):
self.old_model = old_model
self.new_model = new_model
self.test_days = test_days
self.old_results = []
self.new_results = []
def run_comparison(self, data) -> dict:
"""运行对比测试"""
for day_data in data:
old_pred = self.old_model.predict(day_data)
new_pred = self.new_model.predict(day_data)
self.old_results.append(old_pred)
self.new_results.append(new_pred)
# 计算性能对比
old_sharpe = calculate_sharpe(self.old_results)
new_sharpe = calculate_sharpe(self.new_results)
improvement = (new_sharpe - old_sharpe) / abs(old_sharpe) if old_sharpe != 0 else 0
return {
'old_sharpe': old_sharpe,
'new_sharpe': new_sharpe,
'improvement': improvement,
'recommendation': 'DEPLOY_NEW' if improvement >`0`.1 else 'KEEP_OLD'
}
五、生产级漂移监控架构
前面介绍了漂移检测的理论方法,本节展示一个生产级漂移监控系统的实现架构。
5.1 核心设计模式
生产系统需要:
- 多指标监控:IC、PSI、Sharpe 三维度同时监控
- 可配置阈值:不同策略有不同的容忍度
- 持久化存储:漂移历史记录用于分析和审计
- 告警分级:区分警告和严重级别
AlertConfig 配置模式
from dataclasses import dataclass
@dataclass
class AlertConfig:
"""告警阈值配置"""
# IC (Information Coefficient) 阈值
ic_warning: float = 0.02 # IC < 0.02 触发警告
ic_critical: float = 0.01 # IC < 0.01 触发严重告警
# PSI (Population Stability Index) 阈值
psi_warning: float = 0.10 # PSI >`0`.10 分布有变化
psi_critical: float = 0.25 # PSI >`0`.25 分布显著变化
# Sharpe 阈值
sharpe_warning: float = 0.5 # 夏普 < 0.5 性能下降
sharpe_critical: float = 0.0 # 夏普 < 0 策略亏损
阈值解读:
| 指标 | 警告阈值 | 严重阈值 | 业务含义 |
|---|---|---|---|
| IC | < 0.02 | < 0.01 | 信号预测能力衰退 |
| PSI | >0.10 | >0.25 | 特征分布发生偏移 |
| Sharpe | < 0.5 | < 0.0 | 风险调整收益恶化 |
5.2 DriftMetrics 数据结构
每日计算并存储的漂移指标:
from dataclasses import dataclass
from datetime import date
@dataclass
class DriftMetrics:
"""每日漂移指标"""
date: date
strategy_id: str
# IC 指标(信息系数)
ic: float | None = None # 当日 IC
ic_5d_avg: float | None = None # 5日滚动平均
ic_20d_avg: float | None = None # 20日滚动平均
# PSI 指标(分布稳定性)
psi: float | None = None
psi_5d_avg: float | None = None
# Sharpe 指标(风险调整收益)
sharpe_5d: float | None = None # 5日夏普
sharpe_20d: float | None = None # 20日夏普
sharpe_60d: float | None = None # 60日夏普
# 业务指标
daily_return: float | None = None
cumulative_return: float | None = None
trade_count: int = 0
signal_count: int = 0
# 告警状态
ic_alert: bool = False
psi_alert: bool = False
sharpe_alert: bool = False
多时间窗口的意义:
- 5日窗口:快速响应,捕捉短期漂移
- 20日窗口:过滤噪音,确认趋势
- 60日窗口:长期基准,判断结构性变化
5.3 DriftMonitor 核心实现
import logging
import numpy as np
import psycopg
from psycopg.rows import dict_row
logger = logging.getLogger(__name__)
class DriftMonitor:
"""
生产级漂移监控服务
职责:
1. 计算 IC、PSI、Sharpe 指标
2. 与配置的阈值对比判断告警
3. 持久化到 PostgreSQL
4. 支持按策略隔离
"""
def __init__(self, dsn: str, strategy_id: str = "default"):
"""
Args:
dsn: PostgreSQL 连接字符串
strategy_id: 策略标识(支持多策略隔离)
"""
self.dsn = dsn
self.strategy_id = strategy_id
self._config: AlertConfig | None = None
def load_config(self) -> AlertConfig:
"""从数据库加载告警配置"""
with psycopg.connect(self.dsn) as conn:
with conn.cursor(row_factory=dict_row) as cur:
cur.execute(
"""
SELECT ic_warning, ic_critical, psi_warning, psi_critical,
sharpe_warning, sharpe_critical
FROM drift_alert_config
WHERE strategy_id = %s
""",
(self.strategy_id,),
)
row = cur.fetchone()
if row:
self._config = AlertConfig(**row)
else:
self._config = AlertConfig() # 使用默认值
return self._config
def calculate_metrics(self, target_date: date) -> DriftMetrics:
"""
计算指定日期的所有漂移指标
核心逻辑:
1. 获取信号和收益,计算 IC
2. 获取历史收益,计算滚动 Sharpe
3. 根据阈值判断告警状态
"""
if self._config is None:
self.load_config()
metrics = DriftMetrics(date=target_date, strategy_id=self.strategy_id)
# 计算 IC(信号与收益的相关性)
signals, returns = self.get_signals_and_returns(target_date)
if len(signals) >`0` and len(returns) >`0`:
metrics.ic = calculate_ic(signals, returns)
metrics.signal_count = len(signals)
# 计算滚动 Sharpe
daily_returns = self.get_daily_returns(lookback_days=60)
if len(daily_returns) >= 5:
metrics.sharpe_5d = calculate_sharpe(daily_returns[-5:])
if len(daily_returns) >= 20:
metrics.sharpe_20d = calculate_sharpe(daily_returns[-20:])
if len(daily_returns) >= 60:
metrics.sharpe_60d = calculate_sharpe(daily_returns)
# 判断告警状态
config = self._config or AlertConfig()
if metrics.ic is not None:
metrics.ic_alert = metrics.ic < config.ic_critical
if metrics.psi is not None:
metrics.psi_alert = metrics.psi > config.psi_critical
if metrics.sharpe_20d is not None:
metrics.sharpe_alert = metrics.sharpe_20d < config.sharpe_critical
return metrics
5.4 PostgreSQL 持久化
漂移指标需要持久化以支持:
- 历史趋势分析
- 合规审计
- 再训练决策依据
def save_metrics(self, metrics: DriftMetrics) -> None:
"""保存指标到数据库(支持幂等更新)"""
with psycopg.connect(self.dsn) as conn:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO drift_metrics (
date, strategy_id, ic, ic_5d_avg, ic_20d_avg,
psi, psi_5d_avg, sharpe_5d, sharpe_20d, sharpe_60d,
daily_return, cumulative_return, trade_count, signal_count,
ic_alert, psi_alert, sharpe_alert
) VALUES (
%s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
%s, %s, %s, %s, %s, %s, %s
)
ON CONFLICT (date, strategy_id) DO UPDATE SET
ic = EXCLUDED.ic,
sharpe_20d = EXCLUDED.sharpe_20d,
ic_alert = EXCLUDED.ic_alert,
psi_alert = EXCLUDED.psi_alert,
sharpe_alert = EXCLUDED.sharpe_alert
""",
(
metrics.date, metrics.strategy_id, metrics.ic,
metrics.ic_5d_avg, metrics.ic_20d_avg, metrics.psi,
metrics.psi_5d_avg, metrics.sharpe_5d, metrics.sharpe_20d,
metrics.sharpe_60d, metrics.daily_return,
metrics.cumulative_return, metrics.trade_count,
metrics.signal_count, metrics.ic_alert,
metrics.psi_alert, metrics.sharpe_alert,
),
)
conn.commit()
logger.info(f"Saved drift metrics for {metrics.date}")
数据库表结构:
CREATE TABLE drift_metrics (
date DATE NOT NULL,
strategy_id VARCHAR(64) NOT NULL,
ic FLOAT,
ic_5d_avg FLOAT,
ic_20d_avg FLOAT,
psi FLOAT,
psi_5d_avg FLOAT,
sharpe_5d FLOAT,
sharpe_20d FLOAT,
sharpe_60d FLOAT,
daily_return FLOAT,
cumulative_return FLOAT,
trade_count INT DEFAULT 0,
signal_count INT DEFAULT 0,
ic_alert BOOLEAN DEFAULT FALSE,
psi_alert BOOLEAN DEFAULT FALSE,
sharpe_alert BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (date, strategy_id)
);
CREATE TABLE drift_alert_config (
strategy_id VARCHAR(64) PRIMARY KEY,
ic_warning FLOAT DEFAULT 0.02,
ic_critical FLOAT DEFAULT 0.01,
psi_warning FLOAT DEFAULT 0.10,
psi_critical FLOAT DEFAULT 0.25,
sharpe_warning FLOAT DEFAULT 0.5,
sharpe_critical FLOAT DEFAULT 0.0
);
5.5 每日监控作业
def run_daily(self, target_date: date | None = None) -> DriftMetrics:
"""
每日漂移监控作业入口
典型部署:通过 cron 或 Airflow 每日收盘后执行
"""
if target_date is None:
target_date = date.today()
logger.info(f"Running drift monitoring for {target_date}")
metrics = self.calculate_metrics(target_date)
self.save_metrics(metrics)
# 告警日志
if metrics.ic_alert:
logger.warning(f"IC ALERT: IC={metrics.ic:.4f} below threshold")
if metrics.psi_alert:
logger.warning(f"PSI ALERT: PSI={metrics.psi:.4f} above threshold")
if metrics.sharpe_alert:
logger.warning(f"SHARPE ALERT: Sharpe={metrics.sharpe_20d:.4f} below threshold")
return metrics
5.6 集成示例:何时触发再训练
结合漂移监控与再训练决策:
class RetrainOrchestrator:
"""再训练编排器"""
def __init__(self, drift_monitor: DriftMonitor):
self.monitor = drift_monitor
self.consecutive_alerts = 0
self.alert_threshold = 3 # 连续3天告警才触发
def check_retrain_needed(self, target_date: date) -> dict:
"""
判断是否需要触发再训练
规则:
1. IC 连续3天 < 0.01 -> 触发
2. PSI 单次 >`0`.25 -> 触发
3. 20日 Sharpe < 0 -> 触发
"""
metrics = self.monitor.run_daily(target_date)
# 统计连续告警
if metrics.ic_alert or metrics.sharpe_alert:
self.consecutive_alerts += 1
else:
self.consecutive_alerts = 0
# 判断触发条件
triggers = []
if self.consecutive_alerts >= self.alert_threshold:
triggers.append(f"IC/Sharpe连续{self.consecutive_alerts}天告警")
if metrics.psi_alert:
triggers.append(f"PSI={metrics.psi:.3f}超过严重阈值")
if metrics.sharpe_20d is not None and metrics.sharpe_20d < 0:
triggers.append(f"20日Sharpe={metrics.sharpe_20d:.2f}为负")
should_retrain = len(triggers) >`0`
return {
'should_retrain': should_retrain,
'triggers': triggers,
'metrics': metrics,
'action': 'RETRAIN' if should_retrain else 'CONTINUE'
}
# 使用示例
monitor = DriftMonitor(
dsn="postgres://trading:trading@localhost:5432/trading",
strategy_id="momentum_v2"
)
orchestrator = RetrainOrchestrator(monitor)
result = orchestrator.check_retrain_needed(date.today())
if result['should_retrain']:
print(f"触发再训练,原因: {result['triggers']}")
# 调用再训练流水线
5.7 架构要点总结
| 组件 | 职责 | 关键设计 |
|---|---|---|
| AlertConfig | 阈值配置 | 数据类,支持从DB加载 |
| DriftMetrics | 指标载体 | 多时间窗口,告警状态 |
| DriftMonitor | 核心服务 | 计算+存储+告警 |
| PostgreSQL | 持久化 | 幂等写入,支持审计 |
| RetrainOrchestrator | 决策编排 | 连续告警计数,多条件触发 |
生产部署建议:
- 每日收盘后 T+30min 执行(等待数据就绪)
- 告警接入 Slack/PagerDuty
- 监控仪表板展示 IC/PSI/Sharpe 趋势图
- 再训练触发后自动进入 A/B 测试流程
六、总结
检测方法速查
| 方法 | 检测对象 | 敏感度 | 计算成本 | 推荐场景 |
|---|---|---|---|---|
| 性能监控 | 策略收益 | 中 | 低 | 所有策略(必备) |
| K-S检验 | 特征分布 | 高 | 中 | 定期检查(周/月) |
| 卡方检验 | 分类特征 | 高 | 低 | 市场状态标签 |
| CUSUM | 预测误差 | 高 | 低 | 持续监控(日) |
| 综合检测 | 多维度 | 最高 | 中 | 生产系统(推荐) |
再训练策略速查
| 策略 | 触发方式 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|---|
| 定期 | 时间驱动 | 简单可预测 | 可能滞后 | 稳定市场 |
| 触发式 | 漂移驱动 | 及时响应 | 复杂度高 | 波动市场 |
| 在线学习 | 持续更新 | 最快适应 | 不稳定 | 高频场景 |
| 混合 | 定期+触发 | 平衡 | 需要调参 | 生产推荐 |
核心认知:模型漂移不是"如果"而是"何时"的问题。建立完善的检测和再训练机制,是量化策略长期存活的关键。