StratifiedKFold is a more advanced variant of KFold. As we know, the plain KFold process works as below.
import numpy as np
from sklearn.model_selection import KFold

x_data = np.array([[1,2],[2,4],[3,2],[4,4],[5,4],[6,2],[7,4],[8,4],[9,2],[10,4],
                   [11,2],[12,4],[13,2],[14,4],[15,4],[16,2],[17,4],[18,4],[19,2],[20,4]])
y_data = np.array([1,2,3,4,5,1,2,3,4,5,1,2,3,4,5,1,2,3,4,5])

# 10 folds over 20 samples: each test fold holds 2 consecutive samples
kf = KFold(n_splits=10)
for train_index, test_index in kf.split(x_data):
    x_train = x_data[train_index]
    y_train = y_data[train_index]
    x_test = x_data[test_index]
    y_test = y_data[test_index]
    print("train:{0} test:{1}".format(train_index, test_index))
    # train model
    # test model
    # get error rate
train:[ 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19] test:[0 1]
train:[ 0 1 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19] test:[2 3]
train:[ 0 1 2 3 6 7 8 9 10 11 12 13 14 15 16 17 18 19] test:[4 5]
train:[ 0 1 2 3 4 5 8 9 10 11 12 13 14 15 16 17 18 19] test:[6 7]
train:[ 0 1 2 3 4 5 6 7 10 11 12 13 14 15 16 17 18 19] test:[8 9]
train:[ 0 1 2 3 4 5 6 7 8 9 12 13 14 15 16 17 18 19] test:[10 11]
train:[ 0 1 2 3 4 5 6 7 8 9 10 11 14 15 16 17 18 19] test:[12 13]
train:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 16 17 18 19] test:[14 15]
train:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 18 19] test:[16 17]
train:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17] test:[18 19]
Well, you may notice that the test folds don't cover all the categories. Each one holds only 2 samples, so it can contain at most 2 of the 5 classes.
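To make that visible, here is a small sketch that prints which labels land in each KFold test fold. It reuses the same y_data as above; the single-column x_data is only a placeholder I made up, since KFold splits by position and ignores the feature values.

```python
import numpy as np
from sklearn.model_selection import KFold

y_data = np.array([1, 2, 3, 4, 5] * 4)
# Placeholder features (an assumption for this sketch): KFold only needs the sample count.
x_data = np.arange(len(y_data)).reshape(-1, 1)

kf = KFold(n_splits=10)
# Collect the distinct labels that appear in each test fold.
test_labels = [
    sorted(set(y_data[test_index].tolist()))
    for _, test_index in kf.split(x_data)
]
for fold, labels in enumerate(test_labels):
    print("fold {0}: test labels {1}".format(fold, labels))
```

Every fold should list just two labels, never the full set of five.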
With StratifiedKFold, however, it works as below:
import numpy as np
from sklearn.model_selection import StratifiedKFold

x_data = np.array([[1,2],[2,4],[3,2],[4,4],[5,4],[6,2],[7,4],[8,4],[9,2],[10,4],
                   [11,2],[12,4],[13,2],[14,4],[15,4],[16,2],[17,4],[18,4],[19,2],[20,4]])
y_data = np.array([1,2,3,4,5,1,2,3,4,5,1,2,3,4,5,1,2,3,4,5])

# Each class has only 4 samples, so n_splits can be at most 4 here.
# With n_splits=4, every test fold contains one sample of each class 1,2,3,4,5.
kf = StratifiedKFold(n_splits=4)
for train_index, test_index in kf.split(x_data, y_data):
    x_train = x_data[train_index]
    y_train = y_data[train_index]
    x_test = x_data[test_index]
    y_test = y_data[test_index]
    print("train:{0} test:{1}".format(train_index, test_index))
    # train model
    # test model
    # get error rate
train:[ 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19] test:[0 1 2 3 4]
train:[ 0 1 2 3 4 10 11 12 13 14 15 16 17 18 19] test:[5 6 7 8 9]
train:[ 0 1 2 3 4 5 6 7 8 9 15 16 17 18 19] test:[10 11 12 13 14]
train:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14] test:[15 16 17 18 19]
Now every test fold covers all five classes. So for classification we normally choose StratifiedKFold rather than plain KFold.
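If you want to check this yourself, the sketch below counts the labels in each StratifiedKFold test fold, and then tries n_splits=5 to show why 4 is the limit for this data. The x_data here is a made-up placeholder (only the labels drive the stratification), and the ValueError is what I expect from scikit-learn when n_splits exceeds the number of samples in every class, so treat that part as an assumption about your installed version.

```python
from collections import Counter

import numpy as np
from sklearn.model_selection import StratifiedKFold

y_data = np.array([1, 2, 3, 4, 5] * 4)
# Placeholder features (an assumption for this sketch): stratification uses y only.
x_data = np.zeros((len(y_data), 1))

skf = StratifiedKFold(n_splits=4)
# Count how many samples of each class land in every test fold.
fold_counts = [
    Counter(y_data[test_index].tolist())
    for _, test_index in skf.split(x_data, y_data)
]
for fold, counts in enumerate(fold_counts):
    print("fold {0}: {1}".format(fold, dict(sorted(counts.items()))))

# Each class has 4 samples, so asking for 5 folds should be rejected.
try:
    list(StratifiedKFold(n_splits=5).split(x_data, y_data))
    too_many_splits_rejected = False
except ValueError as err:
    too_many_splits_rejected = True
    print("n_splits=5 rejected:", err)
```

Each fold should report exactly one sample per class, which is the balance plain KFold could not give us with this data.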