★★★ This article is based on a featured project from the AI Studio community.
Project Introduction
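The 2022 FIFA World Cup (FIFA World Cup Qatar 2022) was the 22nd FIFA World Cup. It ran from 20 November to 18 December 2022 local time (the opening day was 21 November in Beijing time), at 8 stadiums in 5 cities across Qatar, and the schedule was shortened from the usual 32 days to 29. Qatar was the third Asian country to host the World Cup, after Japan and South Korea, and the first Islamic country to do so. It was also the first host since the Second World War that had never previously qualified for the World Cup finals. The total cost of the tournament was reported at up to US$229 billion, which earned it the label "the most expensive World Cup in history".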
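This project predicts World Cup results from two historical datasets: FIFA World Ranking 1992-2022 and International football results from 1872 to 2022.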
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import joblib
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, roc_curve, roc_auc_score, confusion_matrix
Data Processing
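results.csv contains the following columns:
date - match date
home_team - name of the home team
away_team - name of the away team
home_score - full-time home team goals, including extra time but excluding penalty shoot-outs
away_score - full-time away team goals, including extra time but excluding penalty shoot-outs
tournament - name of the tournament
city - name of the city/town/administrative unit where the match was played
country - name of the country where the match was played
neutral - TRUE/FALSE column indicating whether the match was played at a neutral venue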
# Load the data
results = pd.read_csv('/home/aistudio/work/results.csv', parse_dates=['date'])
results.head()
        date  home_team  away_team  home_score  away_score  tournament     city   country  neutral
0 1872-11-30   Scotland    England           0           0    Friendly  Glasgow  Scotland    False
1 1873-03-08    England   Scotland           4           2    Friendly   London   England    False
2 1874-03-07   Scotland    England           2           1    Friendly  Glasgow  Scotland    False
3 1875-03-06    England   Scotland           2           2    Friendly   London   England    False
4 1876-03-04   Scotland    England           3           0    Friendly  Glasgow  Scotland    False
# Inspect column types and non-null counts
results.info()
RangeIndex: 44289 entries, 0 to 44288
Data columns (total 9 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date 44289 non-null datetime64[ns]
1 home_team 44289 non-null object
2 away_team 44289 non-null object
3 home_score 44289 non-null int64
4 away_score 44289 non-null int64
5 tournament 44289 non-null object
6 city 44289 non-null object
7 country 44289 non-null object
8 neutral 44289 non-null bool
dtypes: bool(1), datetime64[ns](1), int64(2), object(5)
memory usage: 2.7+ MB
# Check for missing values
results.isna().sum()
date 0
home_team 0
away_team 0
home_score 0
away_score 0
tournament 0
city 0
country 0
neutral 0
dtype: int64
# Keep FIFA World Cup and World Cup qualification matches from 1992-12-31 onward (the 1992-2022 ranking period)
fifa_data = results[(results['date'] >= '1992-12-31') & ((results['tournament'] == 'FIFA World Cup') | (results['tournament'] == 'FIFA World Cup qualification'))]
fifa_data = fifa_data.drop(['tournament'], axis=1)
fifa_data = fifa_data.reset_index(drop=True)
fifa_data.head()
        date     home_team  away_team  home_score  away_score          city       country  neutral
0 1993-01-10        Angola   Zimbabwe           1           1        Luanda        Angola    False
1 1993-01-10      DR Congo   Cameroon           1           2      Kinshasa         Zaïre    False
2 1993-01-16  South Africa    Nigeria           0           0  Johannesburg  South Africa    False
3 1993-01-16      Tanzania     Zambia           1           3        Mwanza      Tanzania    False
4 1993-01-17         Benin    Tunisia           0           5       Cotonou         Benin    False
# Inspect column types and non-null counts
fifa_data.info()
RangeIndex: 6359 entries, 0 to 6358
Data columns (total 8 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date 6359 non-null datetime64[ns]
1 home_team 6359 non-null object
2 away_team 6359 non-null object
3 home_score 6359 non-null int64
4 away_score 6359 non-null int64
5 city 6359 non-null object
6 country 6359 non-null object
7 neutral 6359 non-null bool
dtypes: bool(1), datetime64[ns](1), int64(2), object(4)
memory usage: 354.1+ KB
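fifa_ranking.csv contains the following columns:
rank - current rank of the country/region
country_full - full country name
country_abrv - country abbreviation
total_points - current total points
previous_points - total points at the previous ranking release
rank_change - how the rank has changed since the previous release
confederation - FIFA confederation
rank_date - date on which the ranking was calculated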
# Load the data
fifa_ranking = pd.read_csv('/home/aistudio/work/fifa_ranking.csv', parse_dates=['rank_date'])
fifa_ranking.head()
  rank    country_full  country_abrv  total_points  previous_points  rank_change  confederation   rank_date
0   74      Madagascar           MAD          18.0              0.0            0            CAF  1992-12-31
1   52           Qatar           QAT          27.0              0.0            0            AFC  1992-12-31
2   51         Senegal           SEN          27.0              0.0            0            CAF  1992-12-31
3   50     El Salvador           SLV          28.0              0.0            0       CONCACAF  1992-12-31
4   49  Korea Republic           KOR          28.0              0.0            0            AFC  1992-12-31
# Replace full country names: some names differ between fifa_ranking and results
# ('Chinese Taipei' must match the ranking file's capitalization, otherwise the replacement never fires)
fifa_ranking['country_full'] = fifa_ranking['country_full'].str.replace('Brunei Darussalam', 'Brunei').str.replace('Cape Verde Islands', 'Cape Verde').str.replace('Chinese Taipei', 'Taiwan').str.replace('Congo DR', 'DR Congo').str.replace("Côte d'Ivoire", 'Ivory Coast').str.replace('Curacao', 'Curaçao').str.replace('IR Iran', 'Iran').str.replace('Kyrgyz Republic', 'Kyrgyzstan').str.replace('Korea DPR', 'North Korea').str.replace('Korea Republic', 'South Korea').str.replace('St Kitts and Nevis', 'Saint Kitts and Nevis').str.replace('St Lucia', 'Saint Lucia').str.replace('St Vincent and the Grenadines', 'Saint Vincent and the Grenadines').str.replace('Sao Tome e Principe', 'São Tomé and Príncipe').str.replace('US Virgin Islands', 'United States Virgin Islands').str.replace('USA', 'United States')
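One quick way to find which ranking names still need an entry in this replacement chain is to compare the two sets of names. The sketch below uses toy name lists, and unmatched_names is a hypothetical helper. On the real data you would pass fifa_ranking['country_full'] and the home_team and away_team columns of results concatenated together.

```python
import pandas as pd

def unmatched_names(ranking_names, result_names):
    """Return ranking names that never occur as a team name in the results (hypothetical helper)."""
    ranking = set(pd.Series(ranking_names).dropna())
    teams = set(pd.Series(result_names).dropna())
    return sorted(ranking - teams)

# Toy name lists standing in for fifa_ranking['country_full'] and the results team columns
ranking_names = ['USA', 'Korea Republic', 'Brazil', 'Chinese Taipei']
result_names = ['United States', 'South Korea', 'Brazil', 'Taiwan']
print(unmatched_names(ranking_names, result_names))
# ['Chinese Taipei', 'Korea Republic', 'USA']
```

An empty list after the replacements means every ranking name can be joined to a team in results.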
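# Index fifa_ranking by date, group by country, upsample each country's releases to daily rows
# (forward-filling the most recent release), then reset the index
fifa_ranking = fifa_ranking.set_index(['rank_date']).groupby(['country_full'], group_keys=False).resample('D').ffill().reset_index()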
fifa_ranking.head()
   rank_date  rank  country_full  country_abrv  total_points  previous_points  rank_change  confederation
0 2003-01-15   204   Afghanistan           AFG           7.0              0.0            0            AFC
1 2003-01-16   204   Afghanistan           AFG           7.0              0.0            0            AFC
2 2003-01-17   204   Afghanistan           AFG           7.0              0.0            0            AFC
3 2003-01-18   204   Afghanistan           AFG           7.0              0.0            0            AFC
4 2003-01-19   204   Afghanistan           AFG           7.0              0.0            0            AFC
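The ranking file only has rows on release dates, but matches can fall on any day. Upsampling each country's releases to daily frequency and forward-filling means every match date can be matched to the most recent ranking published on or before it. The toy sketch below shows the idea with made-up countries and ranks. It loops over the groups explicitly instead of chaining groupby(...).resample(...), because recent pandas releases deprecate letting a grouped resample operate on the grouping column itself.

```python
import pandas as pd

# Toy ranking releases (illustrative values, not taken from fifa_ranking.csv)
snapshots = pd.DataFrame({
    'rank_date': pd.to_datetime(['2003-01-15', '2003-01-18', '2003-01-15', '2003-01-17']),
    'country_full': ['Country A', 'Country A', 'Country B', 'Country B'],
    'rank': [204, 203, 1, 2],
})

# Per country: index by release date, upsample to one row per day and
# forward-fill so each day carries the most recent release
daily = pd.concat([
    group.set_index('rank_date').resample('D').ffill()
    for _, group in snapshots.groupby('country_full')
]).reset_index()
print(daily)
```

Country A gets four daily rows (rank 204 until the release on 2003-01-18) and Country B gets three.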
# Merge results with fifa_ranking: look up each team's ranking on the match date (home team first, then away team)
fifa_data = fifa_data.merge(fifa_ranking[['country_full', 'total_points', 'previous_points', 'rank', 'rank_change', 'rank_date']], left_on=['date', 'home_team'], right_on=['rank_date', 'country_full']).drop(['rank_date', 'country_full'], axis=1)
fifa_data = fifa_data.merge(fifa_ranking[['country_full', 'total_points', 'previous_points', 'rank', 'rank_change', 'rank_date']], left_on=['date', 'away_team'], right_on=['rank_date', 'country_full'], suffixes=('_home', '_away')).drop(['rank_date', 'country_full'], axis=1)
fifa_data.head()
        date     home_team    away_team  home_score  away_score          city       country  neutral  total_points_home  previous_points_home  rank_home  rank_change_home  total_points_away  previous_points_away  rank_away  rank_change_away
0 1993-01-10        Angola     Zimbabwe           1           1        Luanda        Angola    False               10.0                   0.0        102                 0               27.0                   0.0         54                 0
1 1993-01-16  South Africa      Nigeria           0           0  Johannesburg  South Africa    False                5.0                   0.0        124                 0               50.0                   0.0         13                 0
2 1993-01-16      Tanzania       Zambia           1           3        Mwanza      Tanzania    False               15.0                   0.0         80                 0               38.0                   0.0         32                 0
3 1993-01-17         Benin      Tunisia           0           5       Cotonou         Benin    False                4.0                   0.0        127                 0               35.0                   0.0         38                 0
4 1993-01-17      Botswana  Ivory Coast           0           0      Gaborone      Botswana    False                2.0                   0.0        139                 0               41.0                   0.0         27                 0
# Inspect column types and non-null counts
fifa_data.info()
Int64Index: 6052 entries, 0 to 6051
Data columns (total 16 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date 6052 non-null datetime64[ns]
1 home_team 6052 non-null object
2 away_team 6052 non-null object
3 home_score 6052 non-null int64
4 away_score 6052 non-null int64
5 city 6052 non-null object
6 country 6052 non-null object
7 neutral 6052 non-null bool
8 total_points_home 6052 non-null float64
9 previous_points_home 6052 non-null float64
10 rank_home 6052 non-null int64
11 rank_change_home 6052 non-null int64
12 total_points_away 6052 non-null float64
13 previous_points_away 6052 non-null float64
14 rank_away 6052 non-null int64
15 rank_change_away 6052 non-null int64
dtypes: bool(1), datetime64[ns](1), float64(4), int64(6), object(4)
memory usage: 762.4+ KB
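Both merges are inner joins, which is the default for DataFrame.merge. A match therefore survives only if its home team and its away team each have a ranking row for the match date under the same name. That is why fifa_data shrinks from 6359 to 6052 rows, and why the DR Congo vs Cameroon match of 1993-01-10 from the earlier head() output is missing after the merge. The toy sketch below repeats the same two-step merge. The Angola and Zimbabwe ranks are the ones shown in the merged output, and the ranking table is otherwise illustrative.

```python
import pandas as pd

# Two matches from the head of fifa_data shown earlier
matches = pd.DataFrame({
    'date': pd.to_datetime(['1993-01-10', '1993-01-10']),
    'home_team': ['Angola', 'DR Congo'],
    'away_team': ['Zimbabwe', 'Cameroon'],
})
# Toy daily ranking table: only Angola and Zimbabwe have a row for this date
ranking = pd.DataFrame({
    'rank_date': pd.to_datetime(['1993-01-10', '1993-01-10']),
    'country_full': ['Angola', 'Zimbabwe'],
    'rank': [102, 54],
})

keys = ['rank_date', 'country_full']
merged = (matches
          .merge(ranking, left_on=['date', 'home_team'], right_on=keys)   # inner join on the home team
          .drop(keys, axis=1)
          .merge(ranking, left_on=['date', 'away_team'], right_on=keys,   # inner join on the away team
                 suffixes=('_home', '_away'))
          .drop(keys, axis=1))
print(merged)   # only Angola vs Zimbabwe remains
```

Passing how='left' instead would keep every match and leave NaN rankings for teams that could not be matched.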
# Check for missing values
fifa_data.isna().sum()
date 0
home_team 0
away_team 0
home_score 0
away_score 0
city 0
country 0
neutral 0
total_points_home 0
previous_points_home 0
rank_home 0
rank_change_home 0
total_points_away 0
previous_points_away 0
rank_away 0
rank_change_away 0
dtype: int64
Feature Engineering
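result - match outcome (0: home win, 1: away win, 2: draw)
home_points - points earned by the home team (3: home win, 0: away win, 1: draw)
away_points - points earned by the away team (3: away win, 0: home win, 1: draw)
target - prediction target (0: home win, 1: away win or draw)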
# Feature engineering: derive the outcome columns from the final score
def get_result(home_score, away_score):
    # returns [result, home_points, away_points, target]
    if home_score > away_score:
        return pd.Series([0, 3, 0, 0])
    elif home_score < away_score:
        return pd.Series([1, 0, 3, 1])
    else:
        return pd.Series([2, 1, 1, 1])
outcomes = fifa_data.apply(lambda x: get_result(x['home_score'], x['away_score']), axis=1)
fifa_data[['result', 'home_points', 'away_points', 'target']] = outcomes
fifa_data.head()
        date     home_team    away_team  home_score  away_score          city       country  neutral  total_points_home  previous_points_home  rank_home  rank_change_home  total_points_away  previous_points_away  rank_away  rank_change_away  result  home_points  away_points  target
0 1993-01-10        Angola     Zimbabwe           1           1        Luanda        Angola    False               10.0                   0.0        102                 0               27.0                   0.0         54                 0       2            1            1       1
1 1993-01-16  South Africa      Nigeria           0           0  Johannesburg  South Africa    False                5.0                   0.0        124                 0               50.0                   0.0         13                 0       2            1            1       1
2 1993-01-16      Tanzania       Zambia           1           3        Mwanza      Tanzania    False               15.0                   0.0         80                 0               38.0                   0.0         32                 0       1            0            3       1
3 1993-01-17         Benin      Tunisia           0           5       Cotonou         Benin    False                4.0                   0.0        127                 0               35.0                   0.0         38                 0       1            0            3       1
4 1993-01-17      Botswana  Ivory Coast           0           0      Gaborone      Botswana    False                2.0                   0.0        139                 0               41.0                   0.0         27                 0       2            1            1       1
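get_result is called once per row through apply, which gets slow on large frames. The sketch below is a vectorised alternative based on numpy.select that should produce the same four columns. It is not the author's code, and encode_results is a hypothetical name. The toy scores are the first four rows shown above plus one home win.

```python
import numpy as np
import pandas as pd

def encode_results(df: pd.DataFrame) -> pd.DataFrame:
    """Vectorised counterpart of get_result (hypothetical helper)."""
    out = df.copy()
    home_win = out['home_score'] > out['away_score']
    away_win = out['home_score'] < out['away_score']
    conditions = [home_win, away_win]            # anything else is a draw
    out['result'] = np.select(conditions, [0, 1], default=2)
    out['home_points'] = np.select(conditions, [3, 0], default=1)
    out['away_points'] = np.select(conditions, [0, 3], default=1)
    out['target'] = np.where(home_win, 0, 1)     # 0: home win, 1: away win or draw
    return out

# Scores of the first rows of fifa_data above, plus one home win (2-0)
scores = pd.DataFrame({'home_score': [1, 0, 1, 0, 2], 'away_score': [1, 0, 3, 5, 0]})
encoded = encode_results(scores)
print(encoded)
```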
# Encode the categorical columns as integers (the same LabelEncoder is refit for each column)
label_encoder = LabelEncoder()
labels = ['date', 'home_team', 'away_team', 'city', 'country']
for label in labels:
    fifa_data[f'{label}_encoding'] = label_encoder.fit_transform(fifa_data[label])
fifa_data.head()
        date     home_team    away_team  home_score  away_score          city       country  neutral  total_points_home  previous_points_home  ...  rank_change_away  result  home_points  away_points  target  date_encoding  home_team_encoding  away_team_encoding  city_encoding  country_encoding
0 1993-01-10        Angola     Zimbabwe           1           1        Luanda        Angola    False               10.0                   0.0  ...                 0       2            1            1       1              0                   5                 206            345                 4
1 1993-01-16  South Africa      Nigeria           0           0  Johannesburg  South Africa    False                5.0                   0.0  ...                 0       2            1            1       1              1                 170                 136            274               171
2 1993-01-16      Tanzania       Zambia           1           3        Mwanza      Tanzania    False               15.0                   0.0  ...                 0       1            0            3       1              1                 183                 205            410               184
3 1993-01-17         Benin      Tunisia           0           5       Cotonou         Benin    False                4.0                   0.0  ...                 0       1            0            3       1              2                  21                 189            157                20
4 1993-01-17      Botswana  Ivory Coast           0           0      Gaborone      Botswana    False                2.0                   0.0  ...                 0       2            1            1       1              2                  26                  94            216                25

5 rows × 25 columns
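LabelEncoder maps each distinct value to an integer in sorted (alphabetical) order. That is why Angola gets a small home_team_encoding and Zimbabwe a large one. The codes therefore reflect only alphabetical order and carry no notion of team strength, which fits the author's decision below to drop them. A minimal illustration:

```python
from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
codes = encoder.fit_transform(['Zimbabwe', 'Angola', 'Benin', 'Angola'])
print(codes)             # [2 0 1 0]
print(encoder.classes_)  # ['Angola' 'Benin' 'Zimbabwe']
```

Because the loop above refits one encoder per column, label_encoder.classes_ afterwards holds only the classes of the last column encoded (country).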
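# Plot the correlation heatmap (numeric columns only; pandas >= 2.0 raises an error on text columns otherwise)
plt.figure(figsize=(16, 16))
sns.heatmap(fifa_data.corr(numeric_only=True), annot=True, linewidths=0.2, square=True)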
plt.show()
[Figure: correlation heatmap of fifa_data, including the encoded columns (main_files/main_20_0.png)]
# Drop the encoded columns: their correlations are too low
fifa_data = fifa_data.drop(['date_encoding', 'home_team_encoding', 'away_team_encoding', 'city_encoding', 'country_encoding'], axis=1)
Feature Engineering
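rank_diff - rank difference (home minus away)
rank_change_diff - difference in rank change (home minus away)
total_points_diff - difference in total points (home minus away)
previous_points_diff - difference in previous total points (home minus away)
home_points2rank - home team points / away team rank
away_points2rank - away team points / home team rank
points2rank_diff - home_points2rank minus away_points2rank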
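# Feature engineering: home-minus-away differences and points-per-rank ratios
# Note: home_points/away_points come from this match's own result, so the points2rank
# columns describe the outcome itself rather than pre-match information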
fifa_data['rank_diff'] = fifa_data['rank_home'] - fifa_data['rank_away']
fifa_data['rank_change_diff'] = fifa_data['rank_change_home'] - fifa_data['rank_change_away']
fifa_data['total_points_diff'] = fifa_data['total_points_home'] - fifa_data['total_points_away']
fifa_data['previous_points_diff'] = fifa_data['previous_points_home'] - fifa_data['previous_points_away']
fifa_data['home_points2rank'] = fifa_data['home_points'] / fifa_data['rank_away']
fifa_data['away_points2rank'] = fifa_data['away_points'] / fifa_data['rank_home']
fifa_data['points2rank_diff'] = fifa_data['home_points2rank'] - fifa_data['away_points2rank']
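The following worked example uses the Angola vs Zimbabwe row to show what these definitions produce. A smaller rank number means a stronger team, so a positive rank_diff means the home team is ranked lower. The results agree with rank_diff = 48 and with the points2rank values (0.018519 and 0.009804) that appear in later outputs.

```python
# Angola 1-1 Zimbabwe on 1993-01-10 (first row of fifa_data above)
rank_home, rank_away = 102, 54
total_points_home, total_points_away = 10.0, 27.0
home_points, away_points = 1, 1          # a draw gives one point to each side

rank_diff = rank_home - rank_away                          # positive: home team ranked lower
total_points_diff = total_points_home - total_points_away
home_points2rank = home_points / rank_away                 # home points / away rank
away_points2rank = away_points / rank_home                 # away points / home rank
points2rank_diff = home_points2rank - away_points2rank

print(rank_diff, total_points_diff, round(home_points2rank, 6),
      round(away_points2rank, 6), round(points2rank_diff, 6))
# 48 -17.0 0.018519 0.009804 0.008715
```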
# Plot the correlation heatmap (numeric columns only)
plt.figure(figsize=(16, 16))
sns.heatmap(fifa_data.corr(numeric_only=True), annot=True, linewidths=0.2, square=True)
plt.show()
[Figure: correlation heatmap of fifa_data with the new difference features (main_files/main_24_0.png)]
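# Split the data into a home-side view and an away-side view of each match (to simplify the feature engineering below)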
home_team = fifa_data[['date', 'home_team', 'home_score', 'away_score', 'total_points_home', 'total_points_away', 'previous_points_home', 'previous_points_away', 'rank_home', 'rank_away', 'home_points', 'away_points', 'home_points2rank', 'away_points2rank', 'result']]
away_team = fifa_data[['date', 'away_team', 'away_score', 'home_score', 'total_points_away', 'total_points_home', 'previous_points_away', 'previous_points_home', 'rank_away', 'rank_home', 'away_points', 'home_points', 'away_points2rank', 'home_points2rank', 'result']]
home_team.columns = [h.replace('home_', '').replace('_home', '').replace('away_', 'rival_').replace('_away', '_rival') for h in home_team.columns]
away_team.columns = [a.replace('away_', '').replace('_away', '').replace('home_', 'rival_').replace('_home', '_rival') for a in away_team.columns]
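The renaming turns the home-side and away-side columns into one team-centric schema (team vs rival), so every match contributes one row per team. Applying the home-side rule to a subset of the selected column names shows the mapping:

```python
# Column names of the home-side view (a subset of those selected above)
home_cols = ['home_team', 'home_score', 'away_score', 'total_points_home',
             'total_points_away', 'home_points2rank', 'away_points2rank']
# Same renaming rule as above: the home side becomes "team", the away side becomes "rival"
renamed = [c.replace('home_', '').replace('_home', '')
            .replace('away_', 'rival_').replace('_away', '_rival') for c in home_cols]
print(renamed)
# ['team', 'score', 'rival_score', 'total_points', 'total_points_rival', 'points2rank', 'rival_points2rank']
```

The away-side rule is the mirror image, so both views end up with identical column names and can be stacked.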
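# Merge: stack both views into a single team-centric table
# (pd.concat replaces DataFrame.append, which was removed in pandas 2.0)
team_data = pd.concat([home_team, away_team])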
data_copy = team_data.copy()
team_data.head()
        date          team  score  rival_score  total_points  total_points_rival  previous_points  previous_points_rival  rank  rank_rival  points  rival_points  points2rank  rival_points2rank  result
0 1993-01-10        Angola      1            1          10.0                27.0              0.0                    0.0   102          54       1             1     0.018519           0.009804       2
1 1993-01-16  South Africa      0            0           5.0                50.0              0.0                    0.0   124          13       1             1     0.076923           0.008065       2
2 1993-01-16      Tanzania      1            3          15.0                38.0              0.0                    0.0    80          32       0             3     0.000000           0.037500       1
3 1993-01-17         Benin      0            5           4.0                35.0              0.0                    0.0   127          38       0             3     0.000000           0.023622       1
4 1993-01-17      Botswana      0            0           2.0                41.0              0.0                    0.0   139          27       1             1     0.037037           0.007194       2
Feature Engineering
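All averages are taken over the team's matches strictly before the current match date.
mean_goals - mean goals scored
mean_goals_last5 - mean goals scored in the last five matches
rival_mean_goals - mean goals scored by opponents (i.e. goals conceded)
rival_mean_goals_last5 - mean goals scored by opponents in the last five matches
mean_rank - mean rank
mean_rank_last5 - mean rank over the last five matches
rival_mean_rank - mean opponent rank
rival_mean_rank_last5 - mean opponent rank over the last five matches
mean_points - mean points earned
mean_points_last5 - mean points earned in the last five matches
rival_mean_points - mean points earned by opponents
rival_mean_points_last5 - mean points earned by opponents in the last five matches
mean_points2rank - mean points2rank
mean_points2rank_last5 - mean points2rank over the last five matches
rival_mean_points2rank - mean opponent points2rank
rival_mean_points2rank_last5 - mean opponent points2rank over the last five matches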
# Feature engineering: per-team averages over all earlier matches and over the last five
team_values = []
for idx, row in team_data.iterrows():
    team = row['team']
    date = row['date']
    # all matches this team played strictly before the current date, most recent first
    pasts = team_data.loc[(team_data['team'] == team) & (team_data['date'] < date)].sort_values(by=['date'], ascending=False)
    last5 = pasts.head(5)
    mean_goals = pasts['score'].mean()
    mean_goals_last5 = last5['score'].mean()
    rival_mean_goals = pasts['rival_score'].mean()
    rival_mean_goals_last5 = last5['rival_score'].mean()
    mean_rank = pasts['rank'].mean()
    mean_rank_last5 = last5['rank'].mean()
    rival_mean_rank = pasts['rank_rival'].mean()
    rival_mean_rank_last5 = last5['rank_rival'].mean()
    mean_points = pasts['points'].mean()
    mean_points_last5 = last5['points'].mean()
    rival_mean_points = pasts['rival_points'].mean()
    rival_mean_points_last5 = last5['rival_points'].mean()
    mean_points2rank = pasts['points2rank'].mean()
    mean_points2rank_last5 = last5['points2rank'].mean()
    rival_mean_points2rank = pasts['rival_points2rank'].mean()
    rival_mean_points2rank_last5 = last5['rival_points2rank'].mean()
    team_values.append([mean_goals, mean_goals_last5, rival_mean_goals, rival_mean_goals_last5, mean_rank, mean_rank_last5, rival_mean_rank, rival_mean_rank_last5, mean_points, mean_points_last5, rival_mean_points, rival_mean_points_last5, mean_points2rank, mean_points2rank_last5, rival_mean_points2rank, rival_mean_points2rank_last5])
# Merge the computed features back into team_data
team_columns = ['mean_goals', 'mean_goals_last5', 'rival_mean_goals', 'rival_mean_goals_last5', 'mean_rank', 'mean_rank_last5', 'rival_mean_rank', 'rival_mean_rank_last5', 'mean_points', 'mean_points_last5', 'rival_mean_points', 'rival_mean_points_last5', 'mean_points2rank', 'mean_points2rank_last5', 'rival_mean_points2rank', 'rival_mean_points2rank_last5']
team_value = pd.DataFrame(team_values, columns=team_columns)
team_data = pd.concat([team_data.reset_index(drop=True), team_value], axis=1, ignore_index=False)
team_data.head()
        date          team  score  rival_score  total_points  total_points_rival  previous_points  previous_points_rival  rank  rank_rival  ...  rival_mean_rank  rival_mean_rank_last5  mean_points  mean_points_last5  rival_mean_points  rival_mean_points_last5  mean_points2rank  mean_points2rank_last5  rival_mean_points2rank  rival_mean_points2rank_last5
0 1993-01-10        Angola      1            1          10.0                27.0              0.0                    0.0   102          54  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
1 1993-01-16  South Africa      0            0           5.0                50.0              0.0                    0.0   124          13  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
2 1993-01-16      Tanzania      1            3          15.0                38.0              0.0                    0.0    80          32  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
3 1993-01-17         Benin      0            5           4.0                35.0              0.0                    0.0   127          38  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
4 1993-01-17      Botswana      0            0           2.0                41.0              0.0                    0.0   139          27  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN

5 rows × 31 columns
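The loop above re-filters team_data for every row, so its cost grows roughly quadratically with the number of rows. The sketch below is one possible vectorised alternative. It assumes a unique index and at most one match per team per date. It sorts each team's matches by date and applies shift(1) before the expanding and rolling means, so the current match is excluded. add_prior_means is a hypothetical helper, and the toy frame is illustrative.

```python
import pandas as pd

def add_prior_means(df, value_col, name, window=5):
    """Add per-team means of value_col over all earlier matches and over the last `window`.

    Hypothetical helper; assumes a unique index and at most one match per team per date.
    """
    out = df.copy()
    ordered = out.sort_values(['team', 'date'])
    grouped = ordered.groupby('team')[value_col]
    # shift(1) removes the current match, mirroring the `date < current date` filter in the loop
    out[name] = grouped.transform(lambda s: s.shift(1).expanding().mean())
    out[f'{name}_last{window}'] = grouped.transform(
        lambda s: s.shift(1).rolling(window, min_periods=1).mean())
    return out

toy = pd.DataFrame({
    'date': pd.to_datetime(['2022-01-01', '2022-01-05', '2022-01-03', '2022-01-09']),
    'team': ['A', 'A', 'B', 'A'],
    'score': [1, 3, 2, 2],
})
result = add_prior_means(toy, 'score', 'mean_goals')
print(result)
```

As in the loop, a team's first match gets NaN, which the later dropna() removes. On the real data, this would be applied to team_data after reset_index(drop=True), once per source column (score, rival_score, rank, rank_rival, points and so on).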
# Inspect column types and non-null counts
team_data.info()
RangeIndex: 12104 entries, 0 to 12103
Data columns (total 31 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date 12104 non-null datetime64[ns]
1 team 12104 non-null object
2 score 12104 non-null int64
3 rival_score 12104 non-null int64
4 total_points 12104 non-null float64
5 total_points_rival 12104 non-null float64
6 previous_points 12104 non-null float64
7 previous_points_rival 12104 non-null float64
8 rank 12104 non-null int64
9 rank_rival 12104 non-null int64
10 points 12104 non-null int64
11 rival_points 12104 non-null int64
12 points2rank 12104 non-null float64
13 rival_points2rank 12104 non-null float64
14 result 12104 non-null int64
15 mean_goals 11897 non-null float64
16 mean_goals_last5 11897 non-null float64
17 rival_mean_goals 11897 non-null float64
18 rival_mean_goals_last5 11897 non-null float64
19 mean_rank 11897 non-null float64
20 mean_rank_last5 11897 non-null float64
21 rival_mean_rank 11897 non-null float64
22 rival_mean_rank_last5 11897 non-null float64
23 mean_points 11897 non-null float64
24 mean_points_last5 11897 non-null float64
25 rival_mean_points 11897 non-null float64
26 rival_mean_points_last5 11897 non-null float64
27 mean_points2rank 11897 non-null float64
28 mean_points2rank_last5 11897 non-null float64
29 rival_mean_points2rank 11897 non-null float64
30 rival_mean_points2rank_last5 11897 non-null float64
dtypes: datetime64[ns](1), float64(22), int64(7), object(1)
memory usage: 2.9+ MB
# Split the data again: the first half holds the home-team rows, the second half the away-team rows
home_team_data = team_data.iloc[:int(team_data.shape[0] / 2), :]
away_team_data = team_data.iloc[int(team_data.shape[0] / 2):, :]
away_team_data.tail()
            date         team  score  rival_score  total_points  total_points_rival  previous_points  previous_points_rival  rank  rank_rival  ...  rival_mean_rank  rival_mean_rank_last5  mean_points  mean_points_last5  rival_mean_points  rival_mean_points_last5  mean_points2rank  mean_points2rank_last5  rival_mean_points2rank  rival_mean_points2rank_last5
12099 2022-06-01      Ukraine      3            1       1535.08             1472.66          1535.08                1471.82    27          39  ...        62.433735                   59.0     1.759036                1.8           0.891566                      0.6          0.068496                0.093412                0.028319                      0.023407
12100 2022-06-05      Ukraine      0            1       1535.08             1588.08          1535.08                1578.01    27          18  ...        62.154762                   42.0     1.773810                2.2           0.880952                      0.4          0.068596                0.107183                0.027982                      0.015407
12101 2022-06-07    Australia      2            1       1462.29             1356.99          1486.86                1353.10    42          68  ...        81.504762                   65.6     1.980952                1.0           0.809524                      1.6          0.029632                0.011321                0.021702                      0.044029
12102 2022-06-13         Peru      0            0       1562.32             1462.29          1563.45                1486.86    22          42  ...        34.661654                   35.6     1.090226                2.0           1.676692                      0.8          0.067419                0.065848                0.044702                      0.036364
12103 2022-06-14  New Zealand      0            1       1206.07             1503.09          1161.66                1464.06   101          31  ...       109.125000                  156.2     1.900000                3.0           0.925000                      0.0          0.021997                0.019261                0.010179                      0.000000

5 rows × 31 columns
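The split relies on the row order created when the two views were stacked. Rows 0 to n-1 are the home-side view and rows n to 2n-1 are the away-side view of the same n matches, so row i and row i + n describe the same match. Resetting the index before the side-by-side concat is what pairs them up. The toy sketch below uses placeholder teams and values, and uses add_prefix as a shorthand for the list comprehension in the original.

```python
import pandas as pd

# Toy team-centric table for two matches, A vs C and B vs D:
# home-side rows first, then the away-side rows in the same match order
team_rows = pd.DataFrame({
    'team': ['A', 'B', 'C', 'D'],
    'mean_goals': [1.0, 2.0, 3.0, 4.0],
})
n = len(team_rows) // 2
home = team_rows.iloc[:n].add_prefix('home_').reset_index(drop=True)
away = team_rows.iloc[n:].add_prefix('away_').reset_index(drop=True)
# After resetting both indexes, row i of home and row i of away describe the same match
wide = pd.concat([home, away], axis=1)
print(wide)
```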
# Keep only the last 16 columns (the history features) and prefix them with home_/away_
home_team_data = home_team_data[home_team_data.columns[-16:]]
away_team_data = away_team_data[away_team_data.columns[-16:]]
home_team_data.columns = ['home_' + str(col) for col in home_team_data.columns]
away_team_data.columns = ['away_' + str(col) for col in away_team_data.columns]
away_team_data.tail()
       away_mean_goals  away_mean_goals_last5  away_rival_mean_goals  away_rival_mean_goals_last5  away_mean_rank  away_mean_rank_last5  away_rival_mean_rank  away_rival_mean_rank_last5  away_mean_points  away_mean_points_last5  away_rival_mean_points  away_rival_mean_points_last5  away_mean_points2rank  away_mean_points2rank_last5  away_rival_mean_points2rank  away_rival_mean_points2rank_last5
12099         1.493976                    1.6               0.807229                          1.0       39.265060                  26.0             62.433735                        59.0          1.759036                     1.8                0.891566                           0.6               0.068496                     0.093412                     0.028319                           0.023407
12100         1.511905                    1.8               0.809524                          0.8       39.119048                  26.4             62.154762                        42.0          1.773810                     2.2                0.880952                           0.4               0.068596                     0.107183                     0.027982                           0.015407
12101         2.514286                    1.4               0.809524                          1.2       42.885714                  35.6             81.504762                        65.6          1.980952                     1.0                0.809524                           1.6               0.029632                     0.011321                     0.021702                           0.044029
12102         1.015038                    1.2               1.466165                          0.6       46.751880                  22.4             34.661654                        35.6          1.090226                     2.0                1.676692                           0.8               0.067419                     0.065848                     0.044702                           0.036364
12103         2.200000                    3.6               0.900000                          0.2      101.725000                 111.0            109.125000                       156.2          1.900000                     3.0                0.925000                           0.0               0.021997                     0.019261                     0.010179                           0.000000
# Merge: place home and away features side by side, then attach them to fifa_data
team_data = pd.concat([home_team_data, away_team_data.reset_index(drop=True)], axis=1, ignore_index=False)
fifa_data = pd.concat([fifa_data, team_data.reset_index(drop=True)], axis=1, ignore_index=False)
fifa_data.columns
Index(['date', 'home_team', 'away_team', 'home_score', 'away_score', 'city',
'country', 'neutral', 'total_points_home', 'previous_points_home',
'rank_home', 'rank_change_home', 'total_points_away',
'previous_points_away', 'rank_away', 'rank_change_away', 'result',
'home_points', 'away_points', 'target', 'rank_diff', 'rank_change_diff',
'total_points_diff', 'previous_points_diff', 'home_points2rank',
'away_points2rank', 'points2rank_diff', 'home_mean_goals',
'home_mean_goals_last5', 'home_rival_mean_goals',
'home_rival_mean_goals_last5', 'home_mean_rank', 'home_mean_rank_last5',
'home_rival_mean_rank', 'home_rival_mean_rank_last5',
'home_mean_points', 'home_mean_points_last5', 'home_rival_mean_points',
'home_rival_mean_points_last5', 'home_mean_points2rank',
'home_mean_points2rank_last5', 'home_rival_mean_points2rank',
'home_rival_mean_points2rank_last5', 'away_mean_goals',
'away_mean_goals_last5', 'away_rival_mean_goals',
'away_rival_mean_goals_last5', 'away_mean_rank', 'away_mean_rank_last5',
'away_rival_mean_rank', 'away_rival_mean_rank_last5',
'away_mean_points', 'away_mean_points_last5', 'away_rival_mean_points',
'away_rival_mean_points_last5', 'away_mean_points2rank',
'away_mean_points2rank_last5', 'away_rival_mean_points2rank',
'away_rival_mean_points2rank_last5'],
dtype='object')
# Select the columns kept for the analysis
fifa_data = fifa_data[['date', 'home_team', 'away_team', 'rank_home', 'rank_away', 'home_score', 'away_score', 'result', 'rank_diff', 'rank_change_diff', 'total_points_diff', 'previous_points_diff', 'points2rank_diff', 'home_mean_goals', 'home_mean_goals_last5', 'home_rival_mean_goals', 'home_rival_mean_goals_last5', 'home_mean_rank', 'home_mean_rank_last5', 'home_rival_mean_rank', 'home_rival_mean_rank_last5', 'home_mean_points', 'home_mean_points_last5', 'home_rival_mean_points', 'home_rival_mean_points_last5', 'home_mean_points2rank', 'home_mean_points2rank_last5', 'home_rival_mean_points2rank', 'home_rival_mean_points2rank_last5', 'away_mean_goals', 'away_mean_goals_last5', 'away_rival_mean_goals', 'away_rival_mean_goals_last5', 'away_mean_rank', 'away_mean_rank_last5', 'away_rival_mean_rank', 'away_rival_mean_rank_last5', 'away_mean_points', 'away_mean_points_last5', 'away_rival_mean_points', 'away_rival_mean_points_last5', 'away_mean_points2rank', 'away_mean_points2rank_last5', 'away_rival_mean_points2rank', 'away_rival_mean_points2rank_last5', 'target']]
fifa_data.head()
        date     home_team    away_team  rank_home  rank_away  home_score  away_score  result  rank_diff  rank_change_diff  ...  away_rival_mean_rank_last5  away_mean_points  away_mean_points_last5  away_rival_mean_points  away_rival_mean_points_last5  away_mean_points2rank  away_mean_points2rank_last5  away_rival_mean_points2rank  away_rival_mean_points2rank_last5  target
0 1993-01-10        Angola     Zimbabwe        102         54           1           1       2         48                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
1 1993-01-16  South Africa      Nigeria        124         13           0           0       2        111                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
2 1993-01-16      Tanzania       Zambia         80         32           1           3       1         48                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
3 1993-01-17         Benin      Tunisia        127         38           0           5       1         89                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
4 1993-01-17      Botswana  Ivory Coast        139         27           0           0       2        112                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1

5 rows × 46 columns
# Check for missing values
fifa_data.isna().sum()
date 0
home_team 0
away_team 0
rank_home 0
rank_away 0
home_score 0
away_score 0
result 0
rank_diff 0
rank_change_diff 0
total_points_diff 0
previous_points_diff 0
points2rank_diff 0
home_mean_goals 101
home_mean_goals_last5 101
home_rival_mean_goals 101
home_rival_mean_goals_last5 101
home_mean_rank 101
home_mean_rank_last5 101
home_rival_mean_rank 101
home_rival_mean_rank_last5 101
home_mean_points 101
home_mean_points_last5 101
home_rival_mean_points 101
home_rival_mean_points_last5 101
home_mean_points2rank 101
home_mean_points2rank_last5 101
home_rival_mean_points2rank 101
home_rival_mean_points2rank_last5 101
away_mean_goals 106
away_mean_goals_last5 106
away_rival_mean_goals 106
away_rival_mean_goals_last5 106
away_mean_rank 106
away_mean_rank_last5 106
away_rival_mean_rank 106
away_rival_mean_rank_last5 106
away_mean_points 106
away_mean_points_last5 106
away_rival_mean_points 106
away_rival_mean_points_last5 106
away_mean_points2rank 106
away_mean_points2rank_last5 106
away_rival_mean_points2rank 106
away_rival_mean_points2rank_last5 106
target 0
dtype: int64
# Handle missing values: drop matches where a team has no earlier match (its history features are NaN)
fifa_data = fifa_data.dropna().reset_index(drop=True)
fifa_data.isna().sum()
date 0
home_team 0
away_team 0
rank_home 0
rank_away 0
home_score 0
away_score 0
result 0
rank_diff 0
rank_change_diff 0
total_points_diff 0
previous_points_diff 0
points2rank_diff 0
home_mean_goals 0
home_mean_goals_last5 0
home_rival_mean_goals 0
home_rival_mean_goals_last5 0
home_mean_rank 0
home_mean_rank_last5 0
home_rival_mean_rank 0
home_rival_mean_rank_last5 0
home_mean_points 0
home_mean_points_last5 0
home_rival_mean_points 0
home_rival_mean_points_last5 0
home_mean_points2rank 0
home_mean_points2rank_last5 0
home_rival_mean_points2rank 0
home_rival_mean_points2rank_last5 0
away_mean_goals 0
away_mean_goals_last5 0
away_rival_mean_goals 0
away_rival_mean_goals_last5 0
away_mean_rank 0
away_mean_rank_last5 0
away_rival_mean_rank 0
away_rival_mean_rank_last5 0
away_mean_points 0
away_mean_points_last5 0
away_rival_mean_points 0
away_rival_mean_points_last5 0
away_mean_points2rank 0
away_mean_points2rank_last5 0
away_rival_mean_points2rank 0
away_rival_mean_points2rank_last5 0
target 0
dtype: int64
# Inspect column types and non-null counts
fifa_data.info()
RangeIndex: 5894 entries, 0 to 5893
Data columns (total 46 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date 5894 non-null datetime64[ns]
1 home_team 5894 non-null object
2 away_team 5894 non-null object
3 rank_home 5894 non-null int64
4 rank_away 5894 non-null int64
5 home_score 5894 non-null int64
6 away_score 5894 non-null int64
7 result 5894 non-null int64
8 rank_diff 5894 non-null int64
9 rank_change_diff 5894 non-null int64
10 total_points_diff 5894 non-null float64
11 previous_points_diff 5894 non-null float64
12 points2rank_diff 5894 non-null float64
13 home_mean_goals 5894 non-null float64
14 home_mean_goals_last5 5894 non-null float64
15 home_rival_mean_goals 5894 non-null float64
16 home_rival_mean_goals_last5 5894 non-null float64
17 home_mean_rank 5894 non-null float64
18 home_mean_rank_last5 5894 non-null float64
19 home_rival_mean_rank 5894 non-null float64
20 home_rival_mean_rank_last5 5894 non-null float64
21 home_mean_points 5894 non-null float64
22 home_mean_points_last5 5894 non-null float64
23 home_rival_mean_points 5894 non-null float64
24 home_rival_mean_points_last5 5894 non-null float64
25 home_mean_points2rank 5894 non-null float64
26 home_mean_points2rank_last5 5894 non-null float64
27 home_rival_mean_points2rank 5894 non-null float64
28 home_rival_mean_points2rank_last5 5894 non-null float64
29 away_mean_goals 5894 non-null float64
30 away_mean_goals_last5 5894 non-null float64
31 away_rival_mean_goals 5894 non-null float64
32 away_rival_mean_goals_last5 5894 non-null float64
33 away_mean_rank 5894 non-null float64
34 away_mean_rank_last5 5894 non-null float64
35 away_rival_mean_rank 5894 non-null float64
36 away_rival_mean_rank_last5 5894 non-null float64
37 away_mean_points 5894 non-null float64
38 away_mean_points_last5 5894 non-null float64
39 away_rival_mean_points 5894 non-null float64
40 away_rival_mean_points_last5 5894 non-null float64
41 away_mean_points2rank 5894 non-null float64
42 away_mean_points2rank_last5 5894 non-null float64
43 away_rival_mean_points2rank 5894 non-null float64
44 away_rival_mean_points2rank_last5 5894 non-null float64
45 target 5894 non-null int64
dtypes: datetime64[ns](1), float64(35), int64(8), object(2)
memory usage: 2.1+ MB
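# Split into three groups, each with target: match-level differences, home-team history, away-team history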
data1 = fifa_data[list(fifa_data.columns[8:13].values) + ['target']]
data2 = fifa_data[list(fifa_data.columns[13:29].values) + ['target']]
data3 = fifa_data[fifa_data.columns[29:]]
# Inspect the data
data1.tail()
      rank_diff  rank_change_diff  total_points_diff  previous_points_diff  points2rank_diff  target
5889         12                -1             -62.42                -63.26         -0.076923       1
5890         -9                -2              53.00                 42.93          0.111111       0
5891         26                -6            -105.30               -133.76         -0.044118       1
5892         20                 5            -100.03                -76.59          0.021645       1
5893        -70                -1             297.02                302.40          0.029703       0
# Inspect the data
data2.tail()
      home_mean_goals  home_mean_goals_last5  home_rival_mean_goals  home_rival_mean_goals_last5  home_mean_rank  home_mean_rank_last5  home_rival_mean_rank  home_rival_mean_rank_last5  home_mean_points  home_mean_points_last5  home_rival_mean_points  home_rival_mean_points_last5  home_mean_points2rank  home_mean_points2rank_last5  home_rival_mean_points2rank  home_rival_mean_points2rank_last5  target
5889         1.310811                    1.8               0.986486                          0.4       45.229730                  44.6             61.270270                        81.6          1.675676                     3.0                1.067568                           0.0               0.053158                     0.102165                     0.026165                           0.000000       1
5890         1.338028                    2.2               1.380282                          1.0       56.478873                  19.2             60.746479                        53.6          1.267606                     2.2                1.478873                           0.4               0.036644                     0.238173                     0.033925                           0.021053       0
5891         1.650000                    0.8               1.140000                          0.4       77.240000                  69.4             94.390000                        60.4          1.470000                     1.8                1.350000                           1.2               0.015780                     0.034188                     0.017614                           0.017391       1
5892         2.509434                    1.6               0.811321                          1.2       42.877358                  37.2             81.377358                        64.2          1.990566                     1.4                0.801887                           1.4               0.029768                     0.017478                     0.021497                           0.038147       1
5893         1.557252                    1.2               1.007634                          0.2       43.229008                  44.8             49.969466                        37.4          1.725191                     2.6                1.038168                           0.2               0.052968                     0.097719                     0.030313                           0.004082       0
# Inspect the data
data3.tail()
      away_mean_goals  away_mean_goals_last5  away_rival_mean_goals  away_rival_mean_goals_last5  away_mean_rank  away_mean_rank_last5  away_rival_mean_rank  away_rival_mean_rank_last5  away_mean_points  away_mean_points_last5  away_rival_mean_points  away_rival_mean_points_last5  away_mean_points2rank  away_mean_points2rank_last5  away_rival_mean_points2rank  away_rival_mean_points2rank_last5  target
5889         1.493976                    1.6               0.807229                          1.0       39.265060                  26.0             62.433735                        59.0          1.759036                     1.8                0.891566                           0.6               0.068496                     0.093412                     0.028319                           0.023407       1
5890         1.511905                    1.8               0.809524                          0.8       39.119048                  26.4             62.154762                        42.0          1.773810                     2.2                0.880952                           0.4               0.068596                     0.107183                     0.027982                           0.015407       0
5891         2.514286                    1.4               0.809524                          1.2       42.885714                  35.6             81.504762                        65.6          1.980952                     1.0                0.809524                           1.6               0.029632                     0.011321                     0.021702                           0.044029       1
5892         1.015038                    1.2               1.466165                          0.6       46.751880                  22.4             34.661654                        35.6          1.090226                     2.0                1.676692                           0.8               0.067419                     0.065848                     0.044702                           0.036364       1
5893         2.200000                    3.6               0.900000                          0.2      101.725000                 111.0            109.125000                       156.2          1.900000                     3.0                0.925000                           0.0               0.021997                     0.019261                     0.010179                           0.000000       0
# Violin plots: z-score the feature columns (every column except target), then melt to long format
# (iloc[:, :-1] selects columns; data1[:-1] would drop the last row instead)
standard1 = (data1.iloc[:, :-1] - data1.iloc[:, :-1].mean()) / data1.iloc[:, :-1].std()
standard1['target'] = data1['target']
violin1 = pd.melt(standard1, id_vars='target', var_name='feature', value_name='value')
standard2 = (data2.iloc[:, :-1] - data2.iloc[:, :-1].mean()) / data2.iloc[:, :-1].std()
standard2['target'] = data2['target']
violin2 = pd.melt(standard2, id_vars='target', var_name='feature', value_name='value')
standard3 = (data3.iloc[:, :-1] - data3.iloc[:, :-1].mean()) / data3.iloc[:, :-1].std()
standard3['target'] = data3['target']
violin3 = pd.melt(standard3, id_vars='target', var_name='feature', value_name='value')
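Each feature is z-score standardised (subtract the column mean, divide by the column standard deviation) so that all features share one axis. pd.melt then reshapes the wide table into long (target, feature, value) rows, which is the layout sns.violinplot expects for x='feature', y='value', hue='target'. The toy sketch below uses a few rank_diff, total_points_diff and target values derived from the rows shown above, and selects the feature columns with iloc[:, :-1], i.e. every column except target.

```python
import pandas as pd

# A few rank_diff / total_points_diff / target values derived from the outputs above
toy = pd.DataFrame({
    'rank_diff': [48, 111, -9, 20],
    'total_points_diff': [-17.0, -45.0, 53.0, -100.03],
    'target': [1, 1, 0, 1],
})
features = toy.iloc[:, :-1]                     # every column except target
standard = (features - features.mean()) / features.std()
standard['target'] = toy['target']
long_df = pd.melt(standard, id_vars='target', var_name='feature', value_name='value')
print(long_df)
```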
# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin1, split=True, inner='quart')
plt.show()
[Figure: violin plot of the standardized data1 features by target (main_files/main_43_0.png)]
# Plot the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard1.corr(), annot=True, linewidths=0.2, square=True)
plt.show()
[Figure: correlation heatmap of standard1 (main_files/main_44_0.png)]
# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin2, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()
[Figure: violin plot of the standardized data2 features by target (main_files/main_45_0.png)]
# Plot the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard2.corr(), annot=True, linewidths=0.2, square=True)
plt.show()
[Figure: correlation heatmap of standard2 (main_files/main_46_0.png)]
# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin3, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()
[Figure: violin plot of the standardized data3 features by target (main_files/main_47_0.png)]
# Plot the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard3.corr(), annot=True, linewidths=0.2, square=True)
plt.show()
[Figure: correlation heatmap of standard3 (main_files/main_48_0.png)]
Feature Engineering
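Each difference below is the home team's value minus the away team's value.
mean_goals_diff - difference in mean goals scored
mean_goals_last5_diff - difference in mean goals scored over the last five matches
rival_mean_goals_diff - difference in opponents' mean goals
rival_mean_goals_last5_diff - difference in opponents' mean goals over the last five matches
mean_rank_diff - difference in mean rank
mean_rank_last5_diff - difference in mean rank over the last five matches
rival_mean_rank_diff - difference in opponents' mean rank
rival_mean_rank_last5_diff - difference in opponents' mean rank over the last five matches
mean_points_diff - difference in mean points
mean_points_last5_diff - difference in mean points over the last five matches
rival_mean_points_diff - difference in opponents' mean points
rival_mean_points_last5_diff - difference in opponents' mean points over the last five matches
mean_points2rank_diff - difference in mean points2rank
mean_points2rank_last5_diff - difference in mean points2rank over the last five matches
rival_mean_points2rank_diff - difference in opponents' mean points2rank
rival_mean_points2rank_last5_diff - difference in opponents' mean points2rank over the last five matches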
# Feature engineering: home-minus-away differences of the 16 history features
data = fifa_data.copy()
data.loc[:, 'mean_goals_diff'] = data['home_mean_goals'] - data['away_mean_goals']
data.loc[:, 'mean_goals_last5_diff'] = data['home_mean_goals_last5'] - data['away_mean_goals_last5']
data.loc[:, 'rival_mean_goals_diff'] = data['home_rival_mean_goals'] - data['away_rival_mean_goals']
data.loc[:, 'rival_mean_goals_last5_diff'] = data['home_rival_mean_goals_last5'] - data['away_rival_mean_goals_last5']
data.loc[:, 'mean_rank_diff'] = data['home_mean_rank'] - data['away_mean_rank']
data.loc[:, 'mean_rank_last5_diff'] = data['home_mean_rank_last5'] - data['away_mean_rank_last5']
data.loc[:, 'rival_mean_rank_diff'] = data['home_rival_mean_rank'] - data['away_rival_mean_rank']
data.loc[:, 'rival_mean_rank_last5_diff'] = data['home_rival_mean_rank_last5'] - data['away_rival_mean_rank_last5']
data.loc[:, 'mean_points_diff'] = data['home_mean_points'] - data['away_mean_points']
data.loc[:, 'mean_points_last5_diff'] = data['home_mean_points_last5'] - data['away_mean_points_last5']
data.loc[:, 'rival_mean_points_diff'] = data['home_rival_mean_points'] - data['away_rival_mean_points']
data.loc[:, 'rival_mean_points_last5_diff'] = data['home_rival_mean_points_last5'] - data['away_rival_mean_points_last5']
data.loc[:, 'mean_points2rank_diff'] = data['home_mean_points2rank'] - data['away_mean_points2rank']
data.loc[:, 'mean_points2rank_last5_diff'] = data['home_mean_points2rank_last5'] - data['away_mean_points2rank_last5']
data.loc[:, 'rival_mean_points2rank_diff'] = data['home_rival_mean_points2rank'] - data['away_rival_mean_points2rank']
data.loc[:, 'rival_mean_points2rank_last5_diff'] = data['home_rival_mean_points2rank_last5'] - data['away_rival_mean_points2rank_last5']
data_diff1 = data.iloc[:, -16:]
standard_diff1 = (data_diff1 - data_diff1.mean()) / data_diff1.std()
standard_diff1['target'] = data['target']
violin_diff1 = pd.melt(standard_diff1, id_vars='target', var_name='feature', value_name='value')
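The sixteen assignments above all follow one pattern: `<base>_diff = home_<base> - away_<base>`. The sketch below is a hedged alternative, not the author's code, and generates the same columns in a loop. `BASE_FEATURES` and `add_diff_features` are hypothetical names introduced here, and a toy DataFrame stands in for `fifa_data`.

```python
import pandas as pd

# Hypothetical list of base names; each has matching home_<base> and away_<base> columns.
BASE_FEATURES = [
    'mean_goals', 'mean_goals_last5', 'rival_mean_goals', 'rival_mean_goals_last5',
    'mean_rank', 'mean_rank_last5', 'rival_mean_rank', 'rival_mean_rank_last5',
    'mean_points', 'mean_points_last5', 'rival_mean_points', 'rival_mean_points_last5',
    'mean_points2rank', 'mean_points2rank_last5',
    'rival_mean_points2rank', 'rival_mean_points2rank_last5',
]


def add_diff_features(df, bases=BASE_FEATURES):
    """Return a copy of df with one '<base>_diff' column (home minus away) per base name."""
    out = df.copy()
    for base in bases:
        # New columns are appended in loop order, so out.iloc[:, -len(bases):] selects them.
        out[f'{base}_diff'] = out[f'home_{base}'] - out[f'away_{base}']
    return out


# Toy example with a single base feature
toy = pd.DataFrame({'home_mean_goals': [2.0, 1.0], 'away_mean_goals': [1.5, 1.0]})
result = add_diff_features(toy, bases=['mean_goals'])
print(result['mean_goals_diff'].tolist())  # [0.5, 0.0]
```

Driving the columns from one list keeps the naming consistent and makes it harder to pair a `home_` column with the wrong `away_` column.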
# Plot violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin_diff1, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()
[Figure: split violin plot of the 16 difference features by target (main_files/main_50_0.png)]
# Plot correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard_diff1.corr(), annot=True, linewidths=0.2, square=True)
plt.show()
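[Figure: correlation heatmap of the difference features (main_files/main_51_0.png)]
Feature Engineering
mean_goals2mean_rank_diff - (home mean goals / home mean rank) minus (away mean goals / away mean rank)
rival_mean_goals2mean_rank_diff - (home opponents' mean goals / home mean rank) minus (away opponents' mean goals / away mean rank)
mean_goals2mean_rank_last5_diff - (home mean goals over the last five matches / home mean rank) minus (away mean goals over the last five matches / away mean rank)
rival_mean_goals2mean_rank_last5_diff - (home opponents' mean goals over the last five matches / home mean rank) minus (away opponents' mean goals over the last five matches / away mean rank)
mean_points2mean_rank_diff - (home mean points / home mean rank) minus (away mean points / away mean rank)
rival_mean_points2mean_rank_diff - (home opponents' mean points / home mean rank) minus (away opponents' mean points / away mean rank)
mean_points2mean_rank_last5_diff - (home mean points over the last five matches / home mean rank) minus (away mean points over the last five matches / away mean rank)
rival_mean_points2mean_rank_last5_diff - (home opponents' mean points over the last five matches / home mean rank) minus (away opponents' mean points over the last five matches / away mean rank)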
# Feature engineering: ratio-to-rank difference features
data.loc[:, 'mean_goals2mean_rank_diff'] = (data['home_mean_goals'] / data['home_mean_rank']) - (data['away_mean_goals'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_goals2mean_rank_diff'] = (data['home_rival_mean_goals'] / data['home_mean_rank']) - (data['away_rival_mean_goals'] / data['away_mean_rank'])
data.loc[:, 'mean_goals2mean_rank_last5_diff'] = (data['home_mean_goals_last5'] / data['home_mean_rank']) - (data['away_mean_goals_last5'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_goals2mean_rank_last5_diff'] = (data['home_rival_mean_goals_last5'] / data['home_mean_rank']) - (data['away_rival_mean_goals_last5'] / data['away_mean_rank'])
data.loc[:, 'mean_points2mean_rank_diff'] = (data['home_mean_points'] / data['home_mean_rank']) - (data['away_mean_points'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_points2mean_rank_diff'] = (data['home_rival_mean_points'] / data['home_mean_rank']) - (data['away_rival_mean_points'] / data['away_mean_rank'])
data.loc[:, 'mean_points2mean_rank_last5_diff'] = (data['home_mean_points_last5'] / data['home_mean_rank']) - (data['away_mean_points_last5'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_points2mean_rank_last5_diff'] = (data['home_rival_mean_points_last5'] / data['home_mean_rank']) - (data['away_rival_mean_points_last5'] / data['away_mean_rank'])
data_diff2 = data.iloc[:, -8:]
standard_diff2 = (data_diff2 - data_diff2.mean()) / data_diff2.std()
standard_diff2['target'] = data['target']
violin_diff2 = pd.melt(standard_diff2, id_vars='target', var_name='feature', value_name='value')
# Plot violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin_diff2, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()
[Figure: split violin plot of the 8 ratio difference features by target (main_files/main_53_0.png)]
# Plot box plot
plt.figure(figsize=(15, 10))
sns.boxplot(x='feature', y='value', hue='target', data=violin_diff2)
plt.xticks(rotation=90)
plt.show()
[Figure: box plot of the ratio difference features by target (main_files/main_54_0.png)]
# Plot correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard_diff2.corr(), annot=True, linewidths=0.2, square=True)
plt.show()
[Figure: correlation heatmap of the ratio difference features (main_files/main_55_0.png)]
Select the features with a correlation greater than 0.3:
rank_diff
total_points_diff
previous_points_diff
away_mean_rank
away_mean_rank_last5
away_mean_points
away_rival_mean_points
mean_goals_diff
mean_goals_last5_diff
rival_mean_goals_diff
rival_mean_goals_last5_diff
mean_rank_diff
mean_rank_last5_diff
mean_points_diff
mean_points_last5_diff
rival_mean_points_diff
rival_mean_points_last5_diff
mean_points2rank_diff
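The text does not say which correlation the 0.3 threshold refers to. The sketch below assumes it means the absolute Pearson correlation between each feature and `target`, and automates the selection under that assumption. `select_by_target_corr` is a hypothetical helper, and the toy data is synthetic.

```python
import numpy as np
import pandas as pd


def select_by_target_corr(df, target='target', threshold=0.3):
    """Return numeric feature names whose |Pearson correlation| with `target` exceeds threshold.

    Assumption: the 0.3 rule is applied to correlation with the target column.
    """
    numeric = df.select_dtypes(include='number')
    # corrwith computes the column-wise Pearson correlation against the target Series
    corr = numeric.drop(columns=[target]).corrwith(numeric[target])
    return corr[corr.abs() > threshold].index.tolist()


# Toy data: 'signal' tracks the target, 'noise' is unrelated to it
rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=500)
toy = pd.DataFrame({
    'signal': y + rng.normal(0, 0.3, size=500),
    'noise': rng.normal(0, 1, size=500),
    'target': y,
})
print(select_by_target_corr(toy))  # ['signal']
```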
# Plot joint scatter plot with regression line (sns.jointplot creates its own figure)
sns.jointplot(x='total_points_diff', y='previous_points_diff', data=data, kind='reg')
plt.show()
[Figure: joint plot of total_points_diff vs. previous_points_diff (main_files/main_57_1.png)]
# Plot joint scatter plot with regression line (sns.jointplot creates its own figure)
sns.jointplot(x='away_mean_rank', y='away_mean_rank_last5', data=data, kind='reg')
plt.show()
[Figure: joint plot of away_mean_rank vs. away_mean_rank_last5 (main_files/main_58_1.png)]
# Plot joint scatter plot with regression line (sns.jointplot creates its own figure)
sns.jointplot(x='mean_goals_diff', y='mean_goals_last5_diff', data=data, kind='reg')
plt.show()
[Figure: joint plot of mean_goals_diff vs. mean_goals_last5_diff (main_files/main_59_1.png)]
# Plot joint scatter plot with regression line (sns.jointplot creates its own figure)
sns.jointplot(x='rival_mean_goals_diff', y='rival_mean_goals_last5_diff', data=data, kind='reg')
plt.show()
[Figure: joint plot of rival_mean_goals_diff vs. rival_mean_goals_last5_diff (main_files/main_60_1.png)]
# Plot joint scatter plot with regression line (sns.jointplot creates its own figure)
sns.jointplot(x='mean_rank_diff', y='mean_rank_last5_diff', data=data, kind='reg')
plt.show()
[Figure: joint plot of mean_rank_diff vs. mean_rank_last5_diff (main_files/main_61_1.png)]
# Plot joint scatter plot with regression line (sns.jointplot creates its own figure)
sns.jointplot(x='mean_points_diff', y='mean_points_last5_diff', data=data, kind='reg')
plt.show()
[Figure: joint plot of mean_points_diff vs. mean_points_last5_diff (main_files/main_62_1.png)]
# Plot joint scatter plot with regression line (sns.jointplot creates its own figure)
sns.jointplot(x='rival_mean_points_diff', y='rival_mean_points_last5_diff', data=data, kind='reg')
plt.show()
[Figure: joint plot of rival_mean_points_diff vs. rival_mean_points_last5_diff (main_files/main_63_1.png)]
Drop features whose distributions are too similar to another feature. This removes previous_points_diff and rival_mean_goals_last5_diff from the list above, leaving:
rank_diff
total_points_diff
away_mean_rank
away_mean_rank_last5
away_mean_points
away_rival_mean_points
mean_goals_diff
mean_goals_last5_diff
rival_mean_goals_diff
mean_rank_diff
mean_rank_last5_diff
mean_points_diff
mean_points_last5_diff
rival_mean_points_diff
rival_mean_points_last5_diff
mean_points2rank_diff
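Here the redundant features were picked by eye from the scatter plots. A common automated alternative is to keep a feature only if its absolute correlation with every feature already kept stays at or below a threshold. This technique is swapped in here and is not the author's method. The 0.95 cutoff, the helper name `drop_redundant`, and the toy columns are all assumptions.

```python
import numpy as np
import pandas as pd


def drop_redundant(df, threshold=0.95):
    """Greedily keep columns in order, skipping any whose |corr| with a kept column exceeds threshold."""
    corr = df.corr().abs()
    kept = []
    for col in df.columns:
        if all(corr.loc[col, k] <= threshold for k in kept):
            kept.append(col)
    return kept


# Toy data: 'x_near_copy' is almost identical to 'x', 'z' is independent
rng = np.random.default_rng(1)
x = rng.normal(size=300)
toy = pd.DataFrame({
    'x': x,
    'x_near_copy': x + rng.normal(0, 0.01, size=300),
    'z': rng.normal(size=300),
})
print(drop_redundant(toy))  # ['x', 'z']
```

Column order decides which member of a correlated pair survives, so list the preferred features first.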
# Build the training data
fifa_data = data[['home_team', 'away_team', 'target', 'rank_diff', 'total_points_diff', 'away_mean_rank', 'away_mean_rank_last5', 'away_mean_points', 'away_rival_mean_points', 'mean_goals_diff', 'mean_goals_last5_diff', 'rival_mean_goals_diff', 'mean_rank_diff', 'mean_rank_last5_diff', 'mean_points_diff', 'mean_points_last5_diff', 'rival_mean_points_diff', 'rival_mean_points_last5_diff', 'mean_points2rank_diff']]
fifa_data.head()
   home_team  away_team  target  rank_diff  total_points_diff  away_mean_rank  away_mean_rank_last5  away_mean_points  away_rival_mean_points  mean_goals_diff  mean_goals_last5_diff  rival_mean_goals_diff  mean_rank_diff  mean_rank_last5_diff  mean_points_diff  mean_points_last5_diff  rival_mean_points_diff  rival_mean_points_last5_diff  mean_points2rank_diff
0      Egypt       Togo       0        -80               35.0           101.0                 101.0               0.0                     3.0             -1.0                   -1.0                   -2.0           -80.0                 -80.0               1.0                     1.0                    -2.0                          -2.0               0.009804
1    Morocco      Benin       0        -86               28.0           127.0                 127.0               0.0                     3.0              1.0                    1.0                   -5.0           -86.0                 -86.0               3.0                     3.0                    -3.0                          -3.0               0.035294
2    Tunisia   Ethiopia       0        -47               21.0            85.0                  85.0               0.0                     3.0              5.0                    5.0                   -1.0           -47.0                 -47.0               3.0                     3.0                    -3.0                          -3.0               0.023622
3   Zimbabwe     Angola       0        -48               17.0           102.0                 102.0               1.0                     1.0              1.0                    1.0                    0.5           -48.0                 -48.0               1.0                     1.0                    -0.5                          -0.5              -0.013315
4    Algeria      Ghana       0         -9                5.0            39.0                  39.0               3.0                     0.0             -1.0                   -1.0                    0.0            -9.0                  -9.0              -2.0                    -2.0                     1.0                           1.0              -0.020000
# Inspect data info
fifa_data.info()
RangeIndex: 5894 entries, 0 to 5893
Data columns (total 19 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 home_team 5894 non-null object
1 away_team 5894 non-null object
2 target 5894 non-null int64
3 rank_diff 5894 non-null int64
4 total_points_diff 5894 non-null float64
5 away_mean_rank 5894 non-null float64
6 away_mean_rank_last5 5894 non-null float64
7 away_mean_points 5894 non-null float64
8 away_rival_mean_points 5894 non-null float64
9 mean_goals_diff 5894 non-null float64
10 mean_goals_last5_diff 5894 non-null float64
11 rival_mean_goals_diff 5894 non-null float64
12 mean_rank_diff 5894 non-null float64
13 mean_rank_last5_diff 5894 non-null float64
14 mean_points_diff 5894 non-null float64
15 mean_points_last5_diff 5894 non-null float64
16 rival_mean_points_diff 5894 non-null float64
17 rival_mean_points_last5_diff 5894 non-null float64
18 mean_points2rank_diff 5894 non-null float64
dtypes: float64(15), int64(2), object(2)
memory usage: 875.0+ KB
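Model Training
# Split the data: 80% training, 20% test (features start at column 3, after home_team, away_team, target)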
X_train, X_test, y_train, y_test = train_test_split(fifa_data.iloc[:, 3:], fifa_data['target'], test_size=0.2, shuffle=True, random_state=2022)
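Grid search is an exhaustive method that looks for the best hyperparameters by trying every combination in a predefined grid. First, a set of candidate values is specified for each hyperparameter, and the Cartesian product of these sets forms the grid of combinations. Grid search then trains and evaluates a model for each combination (here with 3-fold cross-validation) and keeps the best-scoring one. Because it visits every combination, it is guaranteed to find the best-scoring combination within the specified grid. The cost, however, multiplies with the number of candidate values per hyperparameter, so it becomes very time-consuming on large datasets with many hyperparameters. Only a single combination is shown here. To run a real search, supply several candidate values per hyperparameter, for example 'max_depth': [3, 5, 7].
# Grid search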
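The following sketch shows a grid with more than one candidate per hyperparameter. It uses a small synthetic dataset rather than the project's data. The candidate values are hypothetical, and `ParameterGrid` is used only to show the size of the Cartesian product that `GridSearchCV` iterates over.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, ParameterGrid

# Hypothetical candidate values: 3 x 2 = 6 combinations in the Cartesian product
param_grid = {
    'max_depth': [3, 5, 7],
    'min_samples_leaf': [5, 10],
}
print(len(ParameterGrid(param_grid)))  # 6

# Small synthetic dataset so the search finishes quickly
X, y = make_classification(n_samples=300, n_features=8, random_state=0)
search = GridSearchCV(
    RandomForestClassifier(n_estimators=50, random_state=0),
    param_grid,
    cv=3,  # 6 combinations x 3 folds = 18 fits, plus one refit of the best combination
)
search.fit(X, y)
print(search.best_params_)
print(search.best_score_)  # mean cross-validated accuracy of the best combination
```

Adding a third hyperparameter with 3 candidates would triple the number of fits. This multiplicative growth is why the text warns about cost.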
rf_params = {
'max_depth': [10],
'max_features': ['sqrt'],
'min_samples_leaf': [10],
'min_samples_split': [10],
'n_estimators': [100]
}
rf_search = GridSearchCV(RandomForestClassifier(), rf_params, cv=3, n_jobs=-1)
rf_search.fit(X_train, y_train)
rf_search.best_params_
{'max_depth': 10,
'max_features': 'sqrt',
'min_samples_leaf': 10,
'min_samples_split': 10,
'n_estimators': 100}
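Random forest is an ensemble algorithm of the Bagging type, a parallel method in which the individual learners have no strong dependence on one another and can be trained at the same time. It combines many decision trees, each trained on a bootstrap sample of the data, and produces the final result by voting (classification) or averaging (regression). This gives the ensemble high accuracy and good generalization. Its performance is usually credited to its two parts. The "random" part (bootstrap sampling and a random subset of features at each split) makes it resistant to overfitting, while the "forest" part (aggregating many trees) makes it more accurate.
# Train the model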
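The two sources of randomness described above can be sketched with plain decision trees: bootstrap rows per tree, a random `sqrt`-sized feature subset per split (`max_features='sqrt'`), and a majority vote. This is a simplified illustration on synthetic data with hypothetical helper names. Note that scikit-learn's `RandomForestClassifier` actually averages the trees' predicted probabilities (soft voting) rather than taking a hard vote.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def fit_bagged_trees(X, y, n_trees=25, seed=0):
    """Fit decision trees on bootstrap samples; each split considers a random sqrt-sized feature subset."""
    rng = np.random.default_rng(seed)
    n = len(X)
    trees = []
    for _ in range(n_trees):
        idx = rng.integers(0, n, size=n)  # bootstrap: draw n row indices with replacement
        tree = DecisionTreeClassifier(max_features='sqrt', random_state=int(rng.integers(1_000_000)))
        tree.fit(X[idx], y[idx])
        trees.append(tree)
    return trees


def predict_majority(trees, X):
    """Hard majority vote over the trees' 0/1 predictions (an odd tree count avoids ties)."""
    votes = np.stack([tree.predict(X) for tree in trees])  # shape: (n_trees, n_samples)
    return (votes.mean(axis=0) > 0.5).astype(int)


# Synthetic binary data: train on the first 300 rows, evaluate on the last 100
X, y = make_classification(n_samples=400, n_features=10, class_sep=2.0, random_state=0)
trees = fit_bagged_trees(X[:300], y[:300])
pred = predict_majority(trees, X[300:])
acc = (pred == y[300:]).mean()
print('bagged trees accuracy:', acc)
```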
rf = RandomForestClassifier(max_depth=10, max_features='sqrt', min_samples_leaf=10, min_samples_split=10, n_estimators=100, random_state=2022)
rf.fit(X_train, y_train)
rf_pred = rf.predict(X_test)
rf_acc = accuracy_score(y_test, rf_pred.astype('int'))
joblib.dump(rf, 'rf.pkl')
print('RandomForest Acc is: ', rf_acc)
RandomForest Acc is: 0.732824427480916
# Grid search
gbdt_params = {
'learning_rate': [0.01],
'max_depth': [5],
'max_features': ['sqrt'],
'min_samples_leaf': [10],
'min_samples_split': [10],
'n_estimators': [500]
}
gbdt_search = GridSearchCV(GradientBoostingClassifier(), gbdt_params, cv=3, n_jobs=-1)
gbdt_search.fit(X_train, y_train)
gbdt_search.best_params_
{'learning_rate': 0.01,
'max_depth': 5,
'max_features': 'sqrt',
'min_samples_leaf': 10,
'min_samples_split': 10,
'n_estimators': 500}
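Gradient boosting decision tree (GBDT) is an ensemble algorithm of the Boosting type, a sequential method in which the individual learners depend strongly on one another and must be trained one after another. Training uses the forward stagewise algorithm to learn greedily. In each iteration t, a new CART regression tree is fit to the residuals between the training targets and the predictions of the previous t-1 trees. More generally, each tree is fit to the negative gradient of the loss, which equals the residual for squared error; scikit-learn's GradientBoostingClassifier uses the log-loss by default.
# Train the model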
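The stagewise residual fitting described above can be sketched for squared-error regression, where the negative gradient is exactly the residual. This is a minimal illustration on synthetic 1-D data with hypothetical helper names, not a re-implementation of `GradientBoostingClassifier`, which uses log-loss and additional details.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def fit_boosted_trees(X, y, n_trees=50, learning_rate=0.1, max_depth=2):
    """Forward stagewise boosting with squared error: each new tree fits the current residuals."""
    f0 = float(np.mean(y))          # initial model: a constant
    pred = np.full(len(y), f0)
    trees = []
    for _ in range(n_trees):
        residual = y - pred         # negative gradient of 0.5 * (y - f)**2 with respect to f
        tree = DecisionTreeRegressor(max_depth=max_depth, random_state=0)
        tree.fit(X, residual)
        pred = pred + learning_rate * tree.predict(X)  # shrunken stagewise update
        trees.append(tree)
    return f0, trees


def predict_boosted(f0, trees, X, learning_rate=0.1):
    """Sum the constant and the shrunken tree outputs (learning_rate must match training)."""
    return f0 + learning_rate * sum(tree.predict(X) for tree in trees)


# Synthetic 1-D regression problem
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + rng.normal(0, 0.1, size=200)
f0, trees = fit_boosted_trees(X, y)
mse_constant = np.mean((y - f0) ** 2)
mse_boosted = np.mean((y - predict_boosted(f0, trees, X)) ** 2)
print(mse_constant, mse_boosted)
```

The `learning_rate` shrinks each tree's contribution. This is the same role it plays in the `learning_rate=0.01` setting above, which is why that model needs many trees (`n_estimators=500`).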
gbdt = GradientBoostingClassifier(learning_rate=0.01, max_depth=5, max_features='sqrt', min_samples_leaf=10, min_samples_split=10, n_estimators=500, random_state=2022)
gbdt.fit(X_train, y_train)
gbdt_pred = gbdt.predict(X_test)
gbdt_acc = accuracy_score(y_test, gbdt_pred.astype('int'))
joblib.dump(gbdt, 'gbdt.pkl')
print('GradientBoosting Acc is: ', gbdt_acc)
GradientBoosting Acc is: 0.7430025445292621
# ROC curves (train vs. test) and confusion matrix (test)
def analyze(model):
    # Positive-class probabilities, computed once per split
    proba_train = model.predict_proba(X_train)[:, 1]
    proba_test = model.predict_proba(X_test)[:, 1]
    plt.figure(figsize=(15, 10))
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    fpr_train, tpr_train, _ = roc_curve(y_train, proba_train)
    plt.plot(fpr_train, tpr_train, label='train')
    fpr_test, tpr_test, _ = roc_curve(y_test, proba_test)
    plt.plot(fpr_test, tpr_test, label='test')
    auc_train = roc_auc_score(y_train, proba_train)
    auc_test = roc_auc_score(y_test, proba_test)
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.legend()
    plt.title('AUC score is %.2f on test and %.2f on train' % (auc_test, auc_train))
    plt.show()
    # Confusion matrix on the test set (rows: true labels, columns: predicted labels)
    plt.figure(figsize=(15, 10))
    matrix = confusion_matrix(y_test, model.predict(X_test))
    sns.heatmap(matrix, annot=True, linewidths=0.2, fmt='d')
    plt.title('confusion_matrix on test')
    plt.show()
# Plot ROC curves and confusion matrix (random forest)
analyze(rf)
[Figure: random forest ROC curves, train vs. test (main_files/main_77_0.png)]
[Figure: random forest confusion matrix on test (main_files/main_77_1.png)]
# Plot ROC curves and confusion matrix (gradient boosting)
analyze(gbdt)
[Figure: gradient boosting ROC curves, train vs. test (main_files/main_78_0.png)]
[Figure: gradient boosting confusion matrix on test (main_files/main_78_1.png)]
2022 World Cup
# Build features from historical data
def get_data(team):
    # All past matches of this team, most recent first (data_copy is built earlier in the project)
    pasts = data_copy[(data_copy['team'] == team)].sort_values(by=['date'], ascending=False)
    last5 = pasts.head(5)
    # Current rank and total points come from the most recent match
    rank = pasts['rank'].values[0]
    total_points = pasts['total_points'].values[0]
    mean_rank = pasts['rank'].mean()
    mean_rank_last5 = last5['rank'].mean()
    mean_goals = pasts['score'].mean()
    mean_goals_last5 = last5['score'].mean()
    mean_points = pasts['points'].mean()
    mean_points_last5 = last5['points'].mean()
    mean_points2rank = pasts['points2rank'].mean()
    rival_mean_goals = pasts['rival_score'].mean()
    rival_mean_points = pasts['rival_points'].mean()
    rival_mean_points_last5 = last5['rival_points'].mean()
    # get_feature() reads this list by position (indices 0-11), so keep the order fixed
    return [rank, total_points, mean_rank, mean_rank_last5, mean_goals, mean_goals_last5, mean_points, mean_points_last5, mean_points2rank, rival_mean_goals, rival_mean_points, rival_mean_points_last5]
def get_feature(team1, team2):
    # team1/team2 are lists from get_data(); team1 plays the role of the home team.
    # Indices: 0 rank, 1 total_points, 2 mean_rank, 3 mean_rank_last5, 4 mean_goals,
    # 5 mean_goals_last5, 6 mean_points, 7 mean_points_last5, 8 mean_points2rank,
    # 9 rival_mean_goals, 10 rival_mean_points, 11 rival_mean_points_last5
    rank_diff = team1[0] - team2[0]
    total_points_diff = team1[1] - team2[1]
    away_mean_rank = team2[2]
    away_mean_rank_last5 = team2[3]
    away_mean_points = team2[6]
    away_rival_mean_points = team2[10]
    mean_goals_diff = team1[4] - team2[4]
    mean_goals_last5_diff = team1[5] - team2[5]
    rival_mean_goals_diff = team1[9] - team2[9]
    mean_rank_diff = team1[2] - team2[2]
    mean_rank_last5_diff = team1[3] - team2[3]
    mean_points_diff = team1[6] - team2[6]
    mean_points_last5_diff = team1[7] - team2[7]
    rival_mean_points_diff = team1[10] - team2[10]
    rival_mean_points_last5_diff = team1[11] - team2[11]
    mean_points2rank_diff = team1[8] - team2[8]
    # Same order as the training feature columns (fifa_data.iloc[:, 3:])
    return [rank_diff, total_points_diff, away_mean_rank, away_mean_rank_last5, away_mean_points, away_rival_mean_points, mean_goals_diff, mean_goals_last5_diff, rival_mean_goals_diff, mean_rank_diff, mean_rank_last5_diff, mean_points_diff, mean_points_last5_diff, rival_mean_points_diff, rival_mean_points_last5_diff, mean_points2rank_diff]
# Load data
fifa_2022 = pd.read_csv('/home/aistudio/work/fifa_2022.csv', parse_dates=['date'])
fifa_2022.head()
        date      home_team     away_team
0 2022-11-20          Qatar       Ecuador
1 2022-11-21        Senegal   Netherlands
2 2022-11-21        England          Iran
3 2022-11-21  United States         Wales
4 2022-11-22      Argentina  Saudi Arabia
# Win/loss prediction
def predict(teams, model):
    home = teams[0]
    away = teams[1]
    team1 = get_data(home)
    team2 = get_data(away)
    # Build features in both orders so the result does not depend on which team is listed first
    feature1 = get_feature(team1, team2)
    feature2 = get_feature(team2, team1)
    # Wrap in DataFrames with the training column names, matching how the model was fitted
    proba1 = model.predict_proba(pd.DataFrame([feature1], columns=X_train.columns))
    proba2 = model.predict_proba(pd.DataFrame([feature2], columns=X_train.columns))
    # Average the two orderings: pred1 scores the first team, pred2 the second team
    pred1 = (proba1[0][0] + proba2[0][1]) / 2
    pred2 = (proba2[0][0] + proba1[0][1]) / 2
    if pred1 < pred2:
        print('%s VS %s: %s wins, probability: %.2f' % (home, away, away, pred2))
    else:
        print('%s VS %s: %s wins, probability: %.2f' % (home, away, home, pred1))
# 2022 World Cup knockout fixtures (the last 16 rows of fifa_2022)
game8 = fifa_2022.iloc[-16:-8, 1:]  # round of 16
game4 = fifa_2022.iloc[-8:-4, 1:]   # quarter-finals
game2 = fifa_2022.iloc[-4:-2, 1:]   # semi-finals
game1 = fifa_2022.iloc[-2:, 1:]     # third-place play-off and final
team8 = []
team4 = []
team2 = []
team1 = []
for idx, row in game8.iterrows():
    home_team = row['home_team']
    away_team = row['away_team']
    team8.append([home_team, away_team])
for idx, row in game4.iterrows():
    home_team = row['home_team']
    away_team = row['away_team']
    team4.append([home_team, away_team])
for idx, row in game2.iterrows():
    home_team = row['home_team']
    away_team = row['away_team']
    team2.append([home_team, away_team])
for idx, row in game1.iterrows():
    home_team = row['home_team']
    away_team = row['away_team']
    team1.append([home_team, away_team])
# Round of 16
for teams in team8:
    predict(teams, gbdt)
Netherlands VS United States: Netherlands wins, probability: 0.66
Argentina VS Australia: Argentina wins, probability: 0.83
France VS Poland: France wins, probability: 0.71
England VS Senegal: England wins, probability: 0.65
Japan VS Croatia: Croatia wins, probability: 0.65
Brazil VS South Korea: Brazil wins, probability: 0.82
Morocco VS Spain: Spain wins, probability: 0.81
Portugal VS Switzerland: Portugal wins, probability: 0.52
# Quarter-finals
for teams in team4:
    predict(teams, gbdt)
Croatia VS Brazil: Brazil wins, probability: 0.70
Netherlands VS Argentina: Argentina wins, probability: 0.60
Morocco VS Portugal: Portugal wins, probability: 0.76
England VS France: France wins, probability: 0.55
# Semi-finals
for teams in team2:
    predict(teams, gbdt)
Argentina VS Croatia: Argentina wins, probability: 0.64
France VS Morocco: France wins, probability: 0.73
# Third-place play-off and final
for teams in team1:
    predict(teams, gbdt)
Croatia VS Morocco: Croatia wins, probability: 0.67
Argentina VS France: Argentina wins, probability: 0.55
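`predict` averages the model's output over both team orderings. For a binary classifier the two averaged scores always sum to 1, and swapping the teams simply swaps them. As a result, the outcome no longer depends on which side is listed as home. The standalone sketch below reproduces that averaging under stated assumptions: a toy `LogisticRegression` on synthetic data stands in for `gbdt`, and the feature vectors are hypothetical "difference" features that flip sign when the teams are swapped.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def symmetric_win_prob(model, feat_ab, feat_ba):
    """Average a binary classifier over both orderings, mirroring predict() above.

    feat_ab: features with team A listed first; feat_ba: features with team B listed first.
    Returns (score_a, score_b), which sum to 1.
    """
    p_ab = model.predict_proba([feat_ab])[0]
    p_ba = model.predict_proba([feat_ba])[0]
    score_a = (p_ab[0] + p_ba[1]) / 2
    score_b = (p_ba[0] + p_ab[1]) / 2
    return score_a, score_b


# Toy model on synthetic data (stands in for the trained gbdt)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] < 0).astype(int)
model = LogisticRegression().fit(X, y)

# Hypothetical difference features: swapping the teams flips their sign
feat_ab = [0.5, -0.2]
feat_ba = [-0.5, 0.2]
a, b = symmetric_win_prob(model, feat_ab, feat_ba)
print(round(a + b, 6))  # 1.0
```

Which class index means "first team wins" depends on how `target` was defined earlier in the project; the averaging itself only relies on the classes being complementary.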
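Summary
On 18 December 2022 (local time), in the final of the 2022 FIFA World Cup in Qatar, Argentina beat France in a penalty shootout to win the title.
This project is intended for learning, as hands-on practice with feature engineering. Possible improvements include data augmentation, finer data granularity, further feature engineering, and better model construction.
Acknowledgements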
Predicting FIFA 2022 World Cup with ML
This article is a repost. Original project link
