请问这代码的问题出在哪里?应该如何修改

用户头像小林的号
2025-09-01 发布

import pandas as pd
import numpy as np
from datetime import datetime

class BacktestEngine:
"""量化交易回测引擎"""
def init(self):

核心参数初始化

self.security = '600660.SH' # 交易标的
self.initial_capital = 1000000 # 初始资金100万
self.cash = self.initial_capital
self.positions = {} # 持仓记录
self.trades = [] # 交易记录

# 策略参数
self.entry_price = None  # 进场价格
self.atr_value = None  # ATR值
self.last_trade_date = None  # 上次交易日

# 数据管理
self.data = None
self.data_start = '2020-01-01'
self.data_end = '2025-12-31'
self.data_generated = False

def generate_data(self):
"""生成模拟市场数据"""
print(f"生成交易数据 {self.data_start} 至 {self.data_end}")

# 创建工作日范围
dates = pd.date_range(self.data_start, self.data_end, freq='B')

# 基础价格序列(线性趋势+随机波动)
base_price = np.linspace(20, 50, len(dates))
volatility = np.random.normal(0, 2, len(dates))
close_prices = base_price + volatility

# 生成高、低价(基于收盘价)
highs = close_prices + np.random.uniform(0.5, 3.0, len(dates))
lows = close_prices - np.random.uniform(0.5, 3.0, len(dates))
volumes = np.random.randint(100000, 500000, len(dates))

# 创建DataFrame
self.data = pd.DataFrame({
    'date': dates,
    'open': close_prices - np.random.uniform(0, 1.0, len(dates)),
    'high': highs,
    'low': lows,
    'close': close_prices,
    'volume': volumes
}).set_index('date')

# 标记数据已生成
self.data_generated = True
return self.data

def calculate_indicators(self):
"""计算技术指标"""
if self.data is None or len(self.data) == 0:
raise ValueError("无法计算指标 - 没有可用数据")

# 计算EMA指标
self.data['ema2'] = self.data['close'].ewm(span=2, adjust=False).mean()
self.data['ema10_high'] = self.data['high'].ewm(span=10, adjust=False).mean()
self.data['ema120'] = self.data['close'].ewm(span=120, adjust=False).mean()

# 计算ATR指标
high_low = self.data['high'] - self.data['low']
high_close = np.abs(self.data['high'] - self.data['close'].shift())
low_close = np.abs(self.data['low'] - self.data['close'].shift())
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
self.data['atr'] = tr.rolling(14).mean()

return self.data

def run_backtest(self):
"""执行回测"""
if not self.data_generated:
self.generate_data()
self.calculate_indicators()

# 打印回测信息
print(f"回测开始 | 时间段: {self.data.index[0].date()} 至 {self.data.index[-1].date()}")
print(f"初始资金: {self.initial_capital:,.2f}")
print(f"数据量: {len(self.data)}个交易日")

# 主回测循环
for idx, (date, row) in enumerate(self.data.iterrows()):
    self.process_day(date, row)
  
    # 进度跟踪
    if (idx + 1) % 100 == 0:
        print(f"处理进度: {idx+1}/{len(self.data)} ({date.date()})")

# 回测结果
self.show_results()

def process_day(self, date, row):
"""处理单个交易日"""
# 跳过已有持仓日
if self.last_trade_date == date.date():
return

# 检查买入条件
if (row['ema2'] > row['ema10_high'] and 
    row['close'] > row['ema120'] and 
    self.security not in self.positions and 
    self.cash > row['close']):
  
    # 执行买入
    self.execute_buy(date, row)
  
# 检查卖出条件
elif self.security in self.positions:
    self.check_sell_conditions(date, row)

def execute_buy(self, date, row):
"""执行买入操作"""
# 计算可买数量
amount = int(self.cash / row['close'])
if amount <= 0:
return

# 更新账户状态
self.positions[self.security] = {
    'amount': amount,
    'entry_price': row['close']
}
self.cash -= amount * row['close']
self.entry_price = row['close']
self.atr_value = row['atr']
self.last_trade_date = date.date()

# 记录交易
self.trades.append({
    'date': date,
    'type': 'buy',
    'price': row['close'],
    'amount': amount
})
print(f"{date.date()} | 买入 {self.security} | "
      f"价格: {row['close']:.2f} | 数量: {amount} | "
      f"剩余现金: {self.cash:,.2f}")

def check_sell_conditions(self, date, row):
"""检查卖出条件"""
position = self.positions[self.security]

# 计算双止损条件
stop_pct = self.entry_price * 0.99
stop_atr = self.entry_price - self.atr_value
stop_price = min(stop_pct, stop_atr)

# 满足止损条件
if row['close'] < stop_price:
    self.execute_sell(date, row, position)

def execute_sell(self, date, row, position):
"""执行卖出操作"""
# 计算头寸价值
amount = position['amount']
position_value = amount * row['close']

# 更新账户状态
self.cash += position_value
profit = position_value - (amount * position['entry_price'])

# 记录交易
self.trades.append({
    'date': date,
    'type': 'sell',
    'price': row['close'],
    'amount': amount,
    'profit': profit
})

# 清理持仓和状态
del self.positions[self.security]
self.entry_price = None
self.atr_value = None
self.last_trade_date = date.date()

print(f"{date.date()} | 卖出 {self.security} | "
      f"价格: {row['close']:.2f} | 数量: {amount} | "
      f"盈利: {profit:,.2f} | 现金余额: {self.cash:,.2f}")

def get_portfolio_value(self):
"""计算组合总价值"""
if not self.positions:
return self.cash

last_price = self.data.iloc[-1]['close']
position_value = sum(pos['amount'] * last_price for pos in self.positions.values())
return self.cash + position_value

def show_results(self):
"""展示回测结果"""
print("\n===== 回测结果 =====")
portfolio_value = self.get_portfolio_value()
total_return = (portfolio_value / self.initial_capital - 1) * 100

# 交易统计
trades = [t for t in self.trades if t['type'] == 'sell']
profitable = [t for t in trades if t['profit'] > 0]
win_rate = len(profitable) / len(trades) * 100 if trades else 0

print(f"初始资金: {self.initial_capital:,.2f}")
print(f"组合总值: {portfolio_value:,.2f} ({total_return:.2f}%)")
print(f"现金余额: {self.cash:,.2f}")
print(f"总交易次数: {len(self.trades)}")
print(f"盈利交易占比: {win_rate:.2f}%")
print(f"平均每笔盈利: {sum(t['profit'] for t in trades)/len(trades):.2f}" if trades else "无交易记录")

运行示例

if name == "main":
print("===== 量化交易回测系统启动 =====")
engine = BacktestEngine()
engine.run_backtest()

评论