import os

import matplotlib.pyplot as plt
import pandas as pd
import pyecharts.charts as pyc
import pyecharts.globals as glbs
import pyecharts.options as opts
from sklearn import linear_model
from sklearn import metrics
from sklearn import model_selection
from sklearn import preprocessing as ppcs

from city_grouping import render, write_js
from brands_clustering import shops, plot_features, features_scaling, get_feat_dict
# Per-city table: the mean of each shop metric, joined with macro-economic
# figures from economy.csv and the number of shops per city. Rows are ordered
# by shop count, busiest city first.
economy = pd.read_csv('economy.csv')

cities = (
    shops[['city', 'avgprice', 'avgscore', 'comments', 'latitude', 'longitude']]
    .groupby('city')
    .mean()
)
cities = cities.assign(
    # city -> value lookups built from economy.csv. A city missing from the
    # file maps to NaN; if a city appears twice, the later row wins (dict semantics).
    GDP=cities.index.map(dict(zip(economy.city, economy['GDP(2022-Q1-Q2)']))),
    Pop=cities.index.map(dict(zip(economy.city, economy['population(2019)']))),
    # number of shop rows per city (count of non-null ids)
    counts=cities.index.map(shops.groupby('city')['id'].count()),
).sort_values(by='counts', ascending=False)
# Draw a pyecharts heatmap of the correlations between the columns of
# `features` (default: the module-level `cities` table, bound when the
# function is defined). `rename` presumably toggles relabelling the columns
# with the Chinese display names below -- confirm.
# The generated temp.html is post-processed with write_js and finally
# renamed to corr.html.
def plot_corr_matrix(features=cities, rename=True):
# Chinese display names: per-capita spend, mean score, mean comment count,
# "longitude centre", "latitude centre", ...
# NOTE(review): in `cities` the 4th/5th columns are latitude/longitude (in
# that order), but the 4th/5th labels here read longitude/latitude -- looks
# swapped if df keeps cities' column order; confirm.
df.columns = ['人均消费', '评分均值', '评论均值', '经度中心', '纬度中心',
# [row position, column position, coefficient] triples, rounded to 4 decimals,
# used as the heatmap series data. cm is presumably df's correlation matrix
# (e.g. df.corr()). cm[x][y] reads column x, row y; for a symmetric
# correlation matrix that equals cm.loc[x, y].
value = [[i, j, round(cm[x][y], 4)] for i, x in enumerate(cm.index) for j, y in enumerate(cm.columns)]
# dark theme, fixed 360x360 px canvas with a dark background colour
init_opts=opts.InitOpts(theme=glbs.ThemeType.DARK, width="360px", height="360px", bg_color='#1a1c1d')
).add_xaxis(list(cm.index)
# series title '相关系数' = "correlation coefficient"
).add_yaxis('相关系数', list(cm.columns), value
label_opts=opts.LabelOpts(is_show=False)
legend_opts=opts.LegendOpts(is_show=False),
# vertical x-axis labels so long column names do not overlap
xaxis_opts=opts.AxisOpts(axislabel_opts=opts.LabelOpts(rotate=90)),
# colour scale pinned to the full correlation range [-1, 1]:
# blue (-1) -> white (0) -> red (+1), laid out horizontally at the top centre
visualmap_opts=opts.VisualMapOpts(
min_=-1, max_=1, precision=2,
range_color=["#0000FF", "#FFFFFF", "#B50A24"],
orient='horizontal', pos_top='1%', pos_left='center',
item_width='12px', item_height='200px'
# Insert an ECharts "grid" (plot-area margin) setting right before the
# "visualMap" key in the generated temp.html. write_js (from city_grouping)
# presumably replaces its 2nd argument with its 3rd inside the file -- confirm.
grid_option = '"grid": {"x": 80, "y": 50, "x2": 30, "y2": 80},\n '
write_js('temp.html', '"visualMap": {', grid_option + '"visualMap": {')
# NOTE(review): os.rename raises FileExistsError on Windows when corr.html
# already exists (e.g. on a re-run); os.replace overwrites on every platform.
os.rename('temp.html', 'corr.html')
# Plot `model`'s predictions for `features` against the ground truth `labels`
# as a line chart, and build a bar chart of the residuals, titled `title`.
#   model:    fitted estimator exposing .predict(features)
#   labels:   ground-truth values; must have an .index (used as x-axis
#             categories), presumably a pandas Series aligned with features
#   title:    chart title, also used as the output file name f'{title}.html'
def plot_predict(model, features, labels, title=''):
pred = model.predict(features)
# residual = prediction - truth: positive means the model over-predicts
residual = - (labels - pred)
# Split the residuals by sign so each bar series shows only its own sign.
# The None placeholders are presumably turned into NaN by pandas -- confirm
# how the chart renders them.
res_up = residual.map(lambda x: x if x >= 0 else None)
res_dw = residual.map(lambda x: x if x < 0 else None)
# shared canvas settings: dark theme, full width, 360 px high
init_options = opts.InitOpts(theme=glbs.ThemeType.DARK, bg_color='#1a1c1d', width='100%', height='360px')
# line chart: truth vs prediction, values rounded to 2 decimals
line = pyc.Line(init_opts=init_options
).add_xaxis(list(labels.index)
).add_yaxis('truth', [round(l, 2) for l in labels], symbol_size=6
'prediction', [round(p, 2) for p in pred],
itemstyle_opts=opts.ItemStyleOpts(color='#A35300')
).set_series_opts(label_opts=opts.LabelOpts(is_show=False)
).set_global_opts(title_opts=opts.TitleOpts(title=title))
# bar chart: positive residuals in red, negative in green, same x-axis
bar = pyc.Bar(init_opts=init_options
).add_xaxis(list(labels.index)
'residual+', [round(r, 2) for r in res_up], bar_width='60%',
itemstyle_opts=opts.ItemStyleOpts(color='#DA6964')
'residual-', [round(r, 2) for r in res_dw], bar_width='60%',
itemstyle_opts=opts.ItemStyleOpts(color='#6F9F71')
).set_series_opts(label_opts=opts.LabelOpts(is_show=False)
# NOTE(review): os.remove raises FileNotFoundError when f'{title}.html' does
# not exist yet, and that same file is written by line.render right after.
# `bar` is not rendered or combined with `line` in these lines. Confirm the
# intended order and output.
os.remove(f'{title}.html')
line.render(f'{title}.html')
def print_coefficients(model):
    """Print the fitted linear equation of ``model`` between two separator rules.

    The equation reads ``y = b + c1 · X1 - c2 · X2 ...``, where ``b`` is
    ``model.intercept_`` and ``c1, c2, ...`` are the entries of ``model.coef_``.
    Every number is rounded to 4 decimals. Negative coefficients are shown
    with a minus sign (``- 0.5 · X2``) rather than ``+ -0.5 · X2``.

    Args:
        model: fitted estimator exposing a scalar ``intercept_`` and a 1-D
            iterable ``coef_`` (e.g. sklearn ``LinearRegression`` on a
            single target).

    Returns:
        The equation string, without the separator rules.
    """
    terms = [f'y = {round(model.intercept_, 4)}']
    for idx, coef in enumerate(model.coef_, start=1):
        # Put the sign in the operator so the equation reads naturally.
        sign = '-' if coef < 0 else '+'
        terms.append(f' {sign} {round(abs(coef), 4)} · X{idx}')
    text = ''.join(terms)
    rule = '-' * 60
    print(f'\n{rule}\n{text}\n{rule}\n')
    return text
# Script entry point, in order:
#   1. fit a linear and a degree-2 polynomial regression of the target
#      column LABEL on per-city features, plot them and print their R^2;
#   2. cross-validate the polynomial model;
#   3. export the polynomial model's coefficients to LABEL + '.json'.
if __name__ == "__main__":
# explanatory features: GDP, population and mean shop coordinates per city
selected = ['GDP', 'Pop', 'latitude', 'longitude']
features = cities[selected]
# Correlation heatmap of the features plus the target, with raw column names
# (rename=False). LABEL and `labels` are not defined in these lines: presumably
# LABEL is a column name of `cities` and labels == cities[LABEL] -- confirm.
plot_corr_matrix(cities[selected+[LABEL]], False)
# Baseline: ordinary least squares on the raw features; prints in-sample R^2.
# NOTE(review): no lr.fit(...) appears before lr is used here, and sklearn
# raises NotFittedError for predict/score on an unfitted estimator -- confirm.
lr = linear_model.LinearRegression()
plot_predict(lr, features, labels, 'Linear')
print(lr.score(features, labels))
# Degree-2 expansion: raw terms, squares and pairwise products. No bias
# column, because LinearRegression fits its own intercept.
poly_feat = ppcs.PolynomialFeatures(degree=2, include_bias=False).fit_transform(features)
# Prints the expansion of a toy 4-feature row, which shows the term order of
# the expanded columns. Looks like a leftover debugging aid; consider removing.
df = pd.DataFrame({'x1': [3], 'x2':[5], 'x3':[7], 'x4':[11]})
print(ppcs.PolynomialFeatures(degree=2, include_bias=False).fit_transform(df))
# Polynomial regression fitted on all cities: plot it, print the fitted
# equation and the in-sample R^2.
poly_lr = linear_model.LinearRegression()
poly_lr.fit(poly_feat, labels)
plot_predict(poly_lr, poly_feat, labels, LABEL)
print_coefficients(poly_lr)
print(poly_lr.score(poly_feat, labels))
# 5-fold cross-validation (KFold, no shuffling): out-of-fold R^2 per fold.
# NOTE(review): `s` is not used in these lines. labels[train] indexes by
# position; if labels is a pandas Series with a non-integer index,
# labels.iloc[train] is the explicit, non-deprecated form.
for train, test in model_selection.KFold(5).split(poly_feat):
plr = linear_model.LinearRegression()
plr.fit(poly_feat[train], labels[train])
s = metrics.r2_score(labels[test], plr.predict(poly_feat[test]))
# Summary of the polynomial model: raw feature names, intercept and one
# coefficient per expanded polynomial term (so `coefs` is longer than `columns`).
'columns': list(features.columns),
'intercept': float(poly_lr.intercept_),
'coefs': [float(c) for c in poly_lr.coef_]
# NOTE(review): converting repr() to JSON by string replacement is fragile:
# commas inside values also get newlines, apostrophes inside names become
# double quotes, and nan/inf are not valid JSON.
# json.dump(model, f, ensure_ascii=False, indent=2) would be robust.
with open(LABEL+'.json', 'w', encoding='utf-8') as f:
f.write(str(model).replace(',', ',\n').replace("'", '"'))