网站建设cms系统/北京计算机培训机构前十名
由于工作中要用到SVR算法,项目组的系统是用java开发的,因此,为了能与项目组同步,算法需要用java来实现,还好台湾大学的林智仁教授推出了Libsvm的源代码,包括java、c++等语言的源代码,在此表示感谢!因此,算法的主体部分不用自己开发了,在源代码的基础上做一些修改就能够应用到自己的项目中了,开源真好!受益了无数人。。。为了弘扬开源的精神,开博记录学习Libsvm java版源代码的过程。下面正式开始!先从SVR回归算法的代码开始,然后逐步扩展到分类算法。希望自己能够坚持下去。加油。
---------------------我是华丽的分割线-----------高手的小jj,我割割割--嘻嘻-------------------
一、初识Libsvm
Libsvm java版本的源代码很容易下载到,为了使算法能够运行,要先把源代码复制到myeclipse中(貌似是废话),在这里,提供一个链接,里面有很好的说明,按照文档中的说明就能够将Libsvm运行起来。
按照上述链接中的方式,自己新建一个main函数,来调用Libsvm算法的源代码,代码如下:
public static void main(String[] args) throws IOException {
String []arg ={ "trainfile\\train1.txt", //存放SVM训练模型用的数据的路径
"trainfile\\model_r.txt"}; //存放SVM通过训练数据训练出来的模型的路径
String []parg={"trainfile\\test2.txt", //这个是存放测试数据
"trainfile\\model_r.txt", //调用的是训练以后的模型
"trainfile\\out_r.txt"}; //生成的结果的文件的路径
System.out.println("........SVM运行开始..........");
//创建一个训练对象
svm_train t = new svm_train();
//创建一个预测或者分类的对象
svm_predict p= new svm_predict();
t.main(arg); //调用
p.main(parg); //调用
}
注意:
1. 该主函数是为了调用Libsvm源代码的,创建了两个字符数组:arg[]和parg[]。
其中arg[]数组存放了两个字符串,"trainfile\\train1.txt"和"trainfile\\model_r.txt",这两个字符串传入到Libsvm中的svm_train();
在trainfile\\train1.txt这个文件里面存储的是训练数据,训练数据的格式如下:
Y index1:特征1 index2:特征2 ....
1:特征1 2:特征2 3:特征3 ...
1:特征1 2:特征2 3:特征3 ...
其中如果是分类问题,lable为类标签。如果是回归问题,lable为具体的实数。为了生成这种格式的数据,可以采用FormatDataLibsvm.exl这个excel文件生成,网上可以下的到。当然,也可以自己写java代码,来生成这种格式的数据。代码随后奉上。而trainfile\\model_r.txt这个文件的作用是存储利用SVM算法训练好的模型。
而parg[]这个数组存放了3个字符串:trainfile\\test2.txt---用于存储测试样本的文件,数据格式与训练样本的格式一样、trainfile\\model_r.txt---前面生成的训练好的模型,在用新样本进行预测时直接使用前文训练好的模型即可、trainfile\\model_r.txt---用于存储模型的预测值,分类问题的话存储的是预测样本每条样本所属的类别,而回归问题的话,存储的是每条样本所对应的预测值。
2. svm_train()与svm_predict()。这两个函数是Libsvm程序包中的源代码,由此即进入到了Libsvm源代码的世界了。
二、Libsvm的真面目
由于本篇是介绍支持向量回归机--SVR的,所以仅从用到SVR算法的代码入手,来分析Libsvm的源代码。分析时采用逐层进入的方式。下面直接上程序代码目录:
+++++++目录+++++
前文是首先调用svm_train()然后调用svm_predict()。那就先从svm_train()说起。
class svm_train{}包含以下几个变量以及函数:
private svm_parameter param; // 用于设置svm模型的参数
private svm_problem prob; // 用来存储样本序号、样本的目标变量Y、样本自变量X.详看class svm_problem
private svm_model model;// ??
private String input_file_name; // 输入文件名
private String model_file_name; // 模型文件名
private String error_msg;//错误信息
private int cross_validation;//交叉验证
private int nr_fold;// ??
private static svm_print_interface svm_print_null = new svm_print_interface() //
private static void exit_with_help() //打印帮助信息
private void do_cross_validation() //交叉验证
private void run(String argv[]) throws IOException //运行svm训练程序
public static void main(String argv[]) throws IOException // 主函数
private static double atof(String s) //将字符串转化为浮点型
private static int atoi(String s) //将字符串转化为int型
private void parse_command_line(String argv[]) //设置参数
private void read_problem() throws IOException //读取错误信息
前文自己写的那个函数,调用的svm训练模型,即
svm_train t = new svm_train();
t.main(arg); //调用svm训练模型
程序进入到class svm_train{ }中的main函数,即 public static void main(String argv[]) throws IOException 该主函数的代码如下:
public static void main(String argv[]) throws IOException {
svm_train t = new svm_train();
t.run(argv);//传进来一个数组,数组里面有两个字符串,一个是训练样本.txt,一个是训练好的模型.txt
}
继续执行run()函数,输入为数组argv[].
private void run(String argv[]) throws IOException{
parse_command_line(argv); // 1.进入到该函数中,获取SVM参数
read_problem(); // 2.进入到该函数中,读取错误信息
error_msg = svm.svm_check_parameter(prob,param); // 3.检查参数
//检查参数,有错误则返回各种参数错误信息,无错误则返回null;
if(error_msg != null)
{
System.err.print("ERROR: "+error_msg+"\n");
System.exit(1);
}
if(cross_validation != 0)
{
do_cross_validation(); // 4.交叉验证
}
else
{
model = svm.svm_train(prob,param); // 5.prob--训练样本,param--SVM模型参数
svm.svm_save_model(model_file_name,model); // 6.保存训练好的模型
}
}
注:该函数一共调用了6个函数,下文一一说明。
首先进入函数1:parse_command_line(argv); // 1.进入到该函数中,获取SVM参数。该函数的输入为argv[],即两个字符串:一个是训练样本.txt,一个是训练好的模型.txt。该函数虽然无返回值,但在函数里面,已经将svm的一些参数存储在param中了,详细参数名称见class svm_parameter,因此模型训练时已经有了所需要的各种参数。函数的详细代码如下:
private void _lineparse_command(String argv[])
{
int i;
svm_print_interface print_func = null; // default printing to stdout
param = new svm_parameter();//开始设置SVM模型的各种参数
// default values
//param.svm_type = svm_parameter.C_SVC;
param.svm_type = svm_parameter.EPSILON_SVR; //此时运行的是SVR算法
param.kernel_type = svm_parameter.RBF; //核函数取径向基核函数
param.degree = 3; //??
param.gamma = 0.08;
//gamma为RBF核函数的参数,默认时=1/num_features 此时设置为0.08 gamma=1/2*sig^2 sig=2.5
//RBF核函数:exp(-gamma*|Xi-Xj|^2)
param.coef0 = 0; //??
param.nu = 0.5; //??
param.cache_size = 100; //设置缓存的大小
param.C = 100; //惩罚参数
param.eps = 0.005; //??
param.p = 0.001; //此值为EPSILON_SVR中EPSILON
param.shrinking = 1; //??
param.probability = 0; //概率估计??
param.nr_weight = 0; //权重??
param.weight_label = new int[0]; //??
param.weight = new double[0];
cross_validation = 0; //交叉验证。0--不进行交叉验证。1--交叉验证
//获取输入参数
// parse options
for(i=0;i
{
if(argv[i].charAt(0) != '-') break;
//由于第一个字符(trainfile\\train1.txt)中的第一个字符不是‘-’,果断break!退出for循环。i=0
if(++i>=argv.length)
exit_with_help();
switch(argv[i-1].charAt(1))
{
case 's':
param.svm_type = atoi(argv[i]);
break;
case 't':
param.kernel_type = atoi(argv[i]);
break;
case 'd':
param.degree = atoi(argv[i]);
break;
case 'g':
param.gamma = atof(argv[i]);
break;
case 'r':
param.coef0 = atof(argv[i]);
break;
case 'n':
param.nu = atof(argv[i]);
break;
case 'm':
param.cache_size = atof(argv[i]);
break;
case 'c':
param.C = atof(argv[i]);
break;
case 'e':
param.eps = atof(argv[i]);
break;
case 'p':
param.p = atof(argv[i]);
break;
case 'h':
param.shrinking = atoi(argv[i]);
break;
case 'b':
param.probability = atoi(argv[i]);
break;
case 'q':
print_func = svm_print_null;
i--;
break;
case 'v':
cross_validation = 1;
nr_fold = atoi(argv[i]);
if(nr_fold < 2)
{
System.err.print("n-fold cross validation: n must >= 2\n");
exit_with_help();
}
break;
case 'w':
++param.nr_weight;
{
int[] old = param.weight_label;
param.weight_label = new int[param.nr_weight];
System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
}
{
double[] old = param.weight;
param.weight = new double[param.nr_weight];
System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
}
param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));
param.weight[param.nr_weight-1] = atof(argv[i]);
break;
default:
System.err.print("Unknown option: " + argv[i-1] + "\n");
exit_with_help();
} //switch循环结束
}//for循环结束
svm.svm_set_print_string_function(print_func); //1.1打印,详见下文说明2;
// determine filenames
if(i>=argv.length) //argv.length=2,而i=0,不执行此语句
exit_with_help();
input_file_name = argv[i]; //训练样本的文件名,即trainfile\\data_train_svr.txt
if(i
model_file_name = argv[i+1]; //模型文件名,即trainfile\\model_r.txt
else//此时不执行下面语句
{
int p = argv[i].lastIndexOf('/');
++p; // whew...
model_file_name = argv[i].substring(p)+".model";
}
}//函数parse_command_line结束
说明:
1.该函数的功能是初始化svm模型的各种参数,本篇用的是SVR算法,初始化了一些参数。
2.该函数调用了一个函数,即函数1.1,由于该函数的输入是print_func = null,经过调用,其输出也为空,即不打印任何信息,因此本文不予深入说明。
此时程序进入函数2:read_problem() ; // 2.进入到该函数中,读取错误信息
该函数无返回值,但在函数体内针对两个错误,用打印输出语句打印出相应的错误信息,其中一个错误信息为:核函数的第一列的标签必须从0开始编号。如果不是从0编号,则打印输出此错误信息。第二个错误为:样本格式有错误,如果样本的编号标签小于0或者样本的编号标签值大于样本的实际个数,则打印输出该错误信息。
private void read_problem() throws IOException {
BufferedReader fp = new BufferedReader(new FileReader(input_file_name));
Vector vy = new Vector();
Vector vx = new Vector();
int max_index = 0;
while(true)
{
String line = fp.readLine();
if(line == null) break;
StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
vy.addElement(atof(st.nextToken()));//atof--将字符串转化为数字
int m = st.countTokens()/2;//训练样本的特征个数,Y,X1,X2...Xm-1.共m个
svm_node[] x = new svm_node[m];
for(int j=0;j
{
x[j] = new svm_node();
x[j].index = atoi(st.nextToken());
x[j].value = atof(st.nextToken());
}
if(m>0) max_index = Math.max(max_index, x[m-1].index);
vx.addElement(x);
}
prob = new svm_problem();
prob.l = vy.size();
prob.x = new svm_node[prob.l][];
for(int i=0;i
prob.x[i] = vx.elementAt(i);
prob.y = new double[prob.l];
for(int i=0;i
prob.y[i] = vy.elementAt(i);
if(param.gamma == 0 && max_index > 0)
param.gamma = 1.0/max_index;
if(param.kernel_type == svm_parameter.PRECOMPUTED)
for(int i=0;i
{
if (prob.x[i][0].index != 0)
{
System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");
System.exit(1);
}
if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
{
System.err.print("Wrong input format: sample_serial_number out of range\n");
System.exit(1);
}
}
fp.close();
}
}
此时程序进入函数3:error_msg = svm.svm_check_parameter(prob,param); // 3.检查参数
该函数是在class svm中,功能是检查svm模型的参数是否正确。
public static String svm_check_parameter(svm_problem prob, svm_parameter param)
该函数的输入是svm_problem prob和svm_parameter param两个类,其中类svm_problem如下:
public class svm_problem implements java.io.Serializable{
public int l;//训练样本中,样本的标签,即第l个训练样本
public double[] y;//训练样本的目标变量Y
public svm_node[][] x;//训练样本的自变量X
}svm_parameter则是svm所需要的各种参数。
输出则是一个字符串,如果匹配到相应的错误,则输出其错误信息,如果没有错误,则返回NULL.
此时程序进入函数4:do_cross_validation(); // 4.交叉验证
由于SVR算法不需要交叉验证,故不执行此函数。而对于分类而言,执行交叉验证操作可增强算法的推广能力。
这里留作以后详细研究。
此时程序进入函数5:model = svm.svm_train(prob,param); //prob--训练样本,param--SVM模型参数
此时进入到了SVM/SVR算法的关键环节--训练模型。欲知详情,请听下回分解。
-----格格-----------------
参考文献:
1.http://wenku.baidu.com/view/54cfa92b453610661ed9f4f6.html-----很好的libsvm安装使用说明