AI应用 量化交易MLOps系统架构部署运维

量化交易系统的AI化部署与运维

回测赚钱的策略成千上万,实盘赚钱的系统万里挑一。差距不在策略本身,而在工程化能力。

从Jupyter到生产系统

原型与生产的鸿沟

Jupyter原型                    生产系统
─────────────────────────────────────────────────
手动运行                      定时自动执行
数据在内存中                    数据持久化和缓存
单标的单周期                    多标的多周期
无异常处理                      全面的错误处理和告警
手动查看结果                    自动监控和通知
代码散落各处                    模块化、可测试

生产级系统架构

┌─────────────────────────────────────────────┐
│                  调度层                       │
│         Airflow / Prefect / Cron            │
├─────────────────────────────────────────────┤
│  数据层      │  计算层      │  执行层        │
│ ─────────── │ ─────────── │ ───────────    │
│ 行情数据采集 │ 特征工程     │ 订单管理        │
│ 另类数据采集 │ 模型推理     │ 风控检查        │
│ 数据质量检查 │ 信号生成     │ 执行报告        │
├─────────────────────────────────────────────┤
│             存储层                            │
│  PostgreSQL + Redis + InfluxDB              │
├─────────────────────────────────────────────┤
│             监控层                            │
│     Grafana + Prometheus + 飞书/钉钉告警     │
└─────────────────────────────────────────────┘

数据管道的AI化

数据质量自动检测

import pandas as pd
import numpy as np
from dataclasses import dataclass
from typing import List, Dict

@dataclass
class DataQualityReport:
    total_rows: int
    missing_pct: float
    duplicates: int
    outliers_count: Dict[str, int]
    stale_data: bool
    schema_valid: bool
    issues: List[str]

class AIDataValidator:
    """AI辅助的数据质量检查器"""

    def __init__(self, llm_client=None):
        self.llm = llm_client

    def validate_ohlcv(self, df: pd.DataFrame) -> DataQualityReport:
        """验证OHLCV数据"""
        issues = []

        # 基础检查
        required_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
        missing_cols = [c for c in required_cols if c not in df.columns]

        if missing_cols:
            issues.append(f"缺少列: {missing_cols}")
            schema_valid = False
        else:
            schema_valid = True
            # 逻辑检查
            if (df['High'] < df['Low']).any():
                issues.append("High < Low 的异常行存在")
            if (df['Close'] > df['High']).any():
                issues.append("Close > High 的异常行存在")
            if (df['Close'] < df['Low']).any():
                issues.append("Close < Low 的异常行存在")

        # 缺失值检查
        missing_pct = df[required_cols].isnull().mean().mean()
        if missing_pct > 0.01:
            issues.append(f"缺失率 {missing_pct:.2%} 超过1%阈值")

        # 异常值检测(基于滚动统计)
        outliers = {}
        if 'Close' in df.columns:
            returns = df['Close'].pct_change()
            zscore = (returns - returns.rolling(252).mean()) / returns.rolling(252).std()
            outlier_count = (abs(zscore) > 5).sum()
            if outlier_count > 0:
                outliers['returns'] = int(outlier_count)
                issues.append(f"收益率异常值: {outlier_count}个 (>5σ)")

        # 数据时效性
        stale_data = (pd.Timestamp.now() - df.index.max()).days > 2
        if stale_data:
            issues.append(f"数据滞后 {(pd.Timestamp.now() - df.index.max()).days} 天")

        return DataQualityReport(
            total_rows=len(df),
            missing_pct=missing_pct,
            duplicates=int(df.index.duplicated().sum()),
            outliers_count=outliers,
            stale_data=stale_data,
            schema_valid=schema_valid,
            issues=issues
        )

    def ai_diagnose(self, report: DataQualityReport) -> str:
        """让AI帮助诊断数据问题"""
        if not self.llm or not report.issues:
            return ""

        prompt = f"""
        量化交易系统的数据质量检查发现以下问题:
        {chr(10).join(f'- {i}' for i in report.issues)}

        请诊断:
        1. 这些问题的根本原因可能是什么?
        2. 哪些问题会导致策略信号异常?
        3. 建议的修复优先级是什么?
        """
        return self.llm.generate(prompt)

模型服务的AI化

自动化特征工程

from sklearn.feature_selection import mutual_info_regression
import warnings

class AutoFeatureEngineer:
    """自动特征工程流水线"""

    def __init__(self, max_features=50):
        self.max_features = max_features
        self.selected_features = None
        self.feature_importance = None

    def generate_features(self, df):
        """自动生成候选特征"""
        features = {}

        # 价格特征
        for window in [5, 10, 20, 50, 200]:
            features[f'return_{window}d'] = df['Close'].pct_change(window)
            features[f'volatility_{window}d'] = df['Close'].pct_change().rolling(window).std()
            features[f'volume_ratio_{window}d'] = df['Volume'] / df['Volume'].rolling(window).mean()
            features[f'high_low_ratio_{window}d'] = (
                df['High'].rolling(window).max() / df['Low'].rolling(window).min()
            )

        # 价格形态(基于滑动窗口的统计)
        for window in [20, 50]:
            roll = df['Close'].rolling(window)
            features[f'price_position_{window}d'] = (
                (df['Close'] - roll.min()) / (roll.max() - roll.min())
            )
            features[f'skewness_{window}d'] = df['Close'].pct_change().rolling(window).skew()
            features[f'kurtosis_{window}d'] = df['Close'].pct_change().rolling(window).kurt()

        return pd.DataFrame(features, index=df.index).dropna()

    def select_features(self, X, y, method='mutual_info'):
        """自动特征选择"""
        if method == 'mutual_info':
            mi_scores = mutual_info_regression(X.fillna(0), y, random_state=42)
            self.feature_importance = pd.Series(mi_scores, index=X.columns).sort_values(ascending=False)

        self.selected_features = self.feature_importance.head(self.max_features).index.tolist()
        print(f"选择了 {len(self.selected_features)}/{len(X.columns)} 个特征")
        print(f"Top 5: {self.selected_features[:5]}")

        return self.selected_features

模型版本管理与热更新

import mlflow
import mlflow.sklearn
from datetime import datetime
import hashlib

class ModelRegistry:
    """模型注册和版本管理"""

    def __init__(self, tracking_uri="sqlite:///model_registry.db"):
        mlflow.set_tracking_uri(tracking_uri)
        self.experiment_name = "quant_trading"

    def register_model(self, model, metrics, params, feature_names):
        """注册模型版本"""
        mlflow.set_experiment(self.experiment_name)

        with mlflow.start_run(run_name=f"train_{datetime.now():%Y%m%d_%H%M}"):
            # 记录参数
            mlflow.log_params(params)

            # 记录指标
            mlflow.log_metrics({
                'sharpe_ratio': metrics['sharpe'],
                'max_drawdown': metrics['max_dd'],
                'win_rate': metrics['win_rate'],
                'profit_factor': metrics['profit_factor'],
                'calmar_ratio': metrics['calmar'],
            })

            # 记录特征
            feature_hash = hashlib.md5(
                ','.join(sorted(feature_names)).encode()
            ).hexdigest()[:8]
            mlflow.log_param('feature_hash', feature_hash)
            mlflow.log_param('n_features', len(feature_names))

            # 保存模型
            mlflow.sklearn.log_model(model, "model")

            run_id = mlflow.active_run().info.run_id
            print(f"模型已注册: run_id={run_id}")
            return run_id

    def promote_to_production(self, run_id):
        """将模型提升到生产环境"""
        client = mlflow.tracking.MlflowClient()
        model_uri = f"runs:/{run_id}/model"

        # 注册模型
        mv = mlflow.register_model(model_uri, "trading_model")

        # 标记为生产版本
        client.transition_model_version_stage(
            name="trading_model",
            version=mv.version,
            stage="Production"
        )
        print(f"模型版本 {mv.version} 已提升至生产环境")

    def load_production_model(self):
        """加载生产环境模型"""
        return mlflow.sklearn.load_model("models:/trading_model/Production")

实时推理引擎

高性能特征服务

import redis
import pickle
from concurrent.futures import ThreadPoolExecutor
import time

class RealTimeInferenceEngine:
    """实时推理引擎"""

    def __init__(self, model, feature_config, redis_url="redis://localhost:6379"):
        self.model = model
        self.feature_config = feature_config
        self.cache = redis.from_url(redis_url)
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.latency_window = []  # 推理延迟记录

    def infer(self, symbol: str, current_data: dict) -> dict:
        """实时推理"""
        start_time = time.perf_counter()

        # 1. 从缓存获取历史特征
        cached_features = self._get_cached_features(symbol)
        if cached_features is None:
            return {'error': '缓存未命中,请检查数据流水线'}

        # 2. 合并实时数据 + 历史特征
        features = self._compute_features(cached_features, current_data)

        # 3. 模型预测
        prediction = self.model.predict_proba([features])[0]

        # 4. 风控检查
        risk_check = self._risk_check(symbol, prediction)

        # 5. 生成信号
        signal = self._generate_signal(prediction, risk_check)

        latency = time.perf_counter() - start_time
        self._record_latency(latency)

        return {
            'symbol': symbol,
            'signal': signal,
            'confidence': float(prediction[1]),
            'risk_approved': risk_check['approved'],
            'latency_ms': latency * 1000,
            'timestamp': datetime.now().isoformat()
        }

    def _risk_check(self, symbol, prediction):
        """多层风控"""
        checks = {
            'max_position': self._check_position_limit(symbol),
            'max_exposure': self._check_exposure_limit(),
            'circuit_breaker': self._check_circuit_breaker(),
            'volatility_filter': self._check_volatility(symbol),
        }
        return {
            'approved': all(checks.values()),
            'details': checks
        }

监控与告警

策略健康度监控

class StrategyMonitor:
    """策略运行监控"""

    def __init__(self, alert_webhook=None):
        self.alert_webhook = alert_webhook
        self.metrics_history = []

    def check_health(self, strategy_state: dict) -> List[str]:
        """检查策略健康状态"""
        alerts = []

        # 1. 信号频率异常
        signal_count = strategy_state.get('signals_today', 0)
        avg_signals = strategy_state.get('avg_daily_signals', 10)
        if signal_count > avg_signals * 3:
            alerts.append(f"⚠️ 今日信号异常增多: {signal_count} (均值: {avg_signals})")
        if signal_count == 0 and datetime.now().hour > 11:
            alerts.append(f"⚠️ 今日无任何信号产生")

        # 2. 数据延迟
        data_lag = strategy_state.get('data_lag_minutes', 0)
        if data_lag > 5:
            alerts.append(f"🚨 数据延迟 {data_lag} 分钟,可能影响信号质量")

        # 3. 模型漂移
        prediction_mean = strategy_state.get('prediction_mean', 0)
        if abs(prediction_mean - 0.5) > 0.2:
            alerts.append(f"⚠️ 模型预测分布偏移: 均值={prediction_mean:.3f}")

        # 4. 回撤监控
        current_drawdown = strategy_state.get('drawdown', 0)
        if current_drawdown > strategy_state.get('max_allowed_dd', 0.15):
            alerts.append(f"🚨 回撤超过限制: {current_drawdown:.2%}")

        # 5. AI综合诊断
        if len(alerts) >= 3 and self.alert_webhook:
            diagnosis = self._ai_diagnose(alerts, strategy_state)
            alerts.append(f"🤖 AI诊断: {diagnosis}")

        return alerts

    def _ai_diagnose(self, alerts, state):
        """AI辅助诊断异常"""
        prompt = f"""
        量化策略出现以下告警:
        {chr(10).join(alerts)}

        策略状态:
        夏普比率: {state.get('sharpe', 'N/A')}
        最近5日收益: {state.get('returns_5d', 'N/A')}
        持仓: {state.get('positions', 'N/A')}

        请给出最可能的1-2个根因和紧急行动建议。不超过50字。
        """
        # 调用LLM
        return "模型漂移 + 市场风格切换,建议暂停交易并重新训练"

关键检查清单

在将AI量化系统从回测推向实盘前,确认以下事项:

上线前 Checklist

□ 样本外回测至少覆盖一个完整牛熊周期
□ 考虑了真实的交易成本(佣金、滑点、印花税)
□ 没有前视偏差(所有特征在预测时点可用)
□ 模型在多个不相关标的上表现一致
□ 压力测试:极端行情(2015股灾、2020熔断)下的表现
□ 金丝雀测试:小资金实盘运行至少1个月
□ 监控和告警系统就位
□ 紧急停止开关可用
□ 有完整的数据备份和恢复方案
□ 定期模型重训练流程已自动化

从回测到实盘,90%的努力不在策略算法上,而在工程韧性上。一个好的量化系统,即使策略失效,也不会造成灾难性损失——这才是真正的生产级。