首页 新闻 会员 周边

用hyperopt自动调参出现错误,求大神帮忙解决

0
悬赏园豆:100 [待解决问题]
#! _*_coding: utf-8 _*_

import pandas as pd
from xgboost.sklearn import XGBClassifier as xgb
from hyperopt import hp, tpe, fmin
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


params = {
'max_depth': hp.randint('max_depth', 3, 10),
'colsample': hp.uniform('colsample', 0.5, 1),
'learning_rate': hp.uniform('learning_rate', 0.01, 0.2)
}

data = pd.read_csv('./train.csv', encoding='utf-8')
x = data.iloc[:, 2:-1]
x = StandardScaler().fit_transform(x)
y = data.iloc[:, -1:]
train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.3, random_state=33)
count = 0


def function(args):
print(args)
clf = xgb(**args)
# 训练模型
clf.fit(train_x, train_y)
# 预测测试集
prediction = clf.predict(test_x)

global count
count = count + 1
score = accuracy_score(test_y, prediction)
print("第%d次,测试集正确率为:" % count, score)
return -score


best = fmin(function, params, algo=tpe.suggest, max_evals=10)

print('最佳参数:', best)
clf = xgb(**best)
print(clf)

报错如下:

Traceback (most recent call last):
File "hyperopt_xgb.py", line 41, in <module>
best = fmin(function, params, algo=tpe.suggest, max_evals=10)
File "/usr/local/lib/python3.5/dist-packages/hyperopt/fmin.py", line 320, in fmin
rval.exhaust()
File "/usr/local/lib/python3.5/dist-packages/hyperopt/fmin.py", line 199, in exhaust
self.run(self.max_evals - n_done, block_until_done=self.async)
File "/usr/local/lib/python3.5/dist-packages/hyperopt/fmin.py", line 157, in run
self.rstate.randint(2 ** 31 - 1))
File "/usr/local/lib/python3.5/dist-packages/hyperopt/tpe.py", line 812, in suggest
= tpe_transform(domain, prior_weight, gamma)
File "/usr/local/lib/python3.5/dist-packages/hyperopt/tpe.py", line 793, in tpe_transform
s_prior_weight
File "/usr/local/lib/python3.5/dist-packages/hyperopt/tpe.py", line 684, in build_posterior
b_post = fn(*b_args, **dict(named_args))
TypeError: ap_categorical_sampler() got multiple values for argument 'size'




王小伍sky的主页 王小伍sky | 初学一级 | 园豆:8
提问于:2018-08-21 09:17
< >
分享
所有回答(1)
0

'max_depth': hp.randint('max_depth', 3, 10),这个 hp.randint只接受两个参数,下限默认就是0

uhyuuhyu | 园豆:202 (菜鸟二级) | 2019-08-14 10:48
清除回答草稿
   您需要登录以后才能回答,未注册用户请先注册