天天看点

java 交叉验证CrossValidation 完整版设计

一、认识

交叉验证(Cross-Validation): 有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法。于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证。 一开始的子集被称为训练集。而其它的子集则被称为验证集或测试集。WIKI

交叉验证对于人工智能,机器学习,模式识别,分类器等研究都具有很强的指导与验证意义。

基本思想是把在某种意义下将原始数据(dataset)进行分组,一部分做为训练集(train set),另一部分做为验证集(validation set or test set),首先用训练集对分类器进行训练,在利用验证集来测试训练得到的模型(model),以此来做为评价分类器的性能指标.

二、设计

package recomendation;

//交叉验证
public class CrossValidation {
	
	/**
     * The number of rounds of cross validation.交叉验证的轮数。
     */
    public final int k;
    /**
     * The index of training instances.训练实例的索引。
     */
    public final int[][] train;
    /**
     * The index of testing instances.
     */
    public final int[][] test;

    /**
     * Constructor.构造函数。
     * @param n the number of samples.样本数。
     * @param k the number of rounds of cross validation.交叉验证的轮数
     */
    public CrossValidation(int n, int k) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);//样本数量无效
        }

        if (k < 0 || k > n) {
            throw new IllegalArgumentException("Invalid number of CV rounds: " + k);//无效
        }

        this.k = k;
        
        int[] index = new int[n];

        // insert integers 0..n-1
        for (int i = 0; i < n; i++)
            index[i] = i;

        // shuffle  ,to create permutation of array随机,以创建数组的排列
        for (int i = 0; i < n; i++) {
            int r = (int) (Math.random() * (i+1));     // int between 0 and i  //0至i之间的int
            int swap = index[r];
            index[r] = index[i];
            index[i] = swap;
        }
        
        train = new int[k][];//训练集
        test = new int[k][];//测试集

        int chunk = n / k;
        for (int i = 0; i < k; i++) {
            int start = chunk * i;
            int end = chunk * (i + 1);
            if (i == k-1) end = n;

            train[i] = new int[n - end + start];
            test[i] = new int[end - start];
            for (int j = 0, p = 0, q = 0; j < n; j++) {
                if (j >= start && j < end) {
                    test[i][p++] = index[j];
                } else {
                    train[i][q++] = index[j];
                }
            }
        }
    }
}