练习
用迭代式MapReduce实现SSSP(Single-Source Shortest Path,单源最短路径算法)。参考资料如下:
测试用例(每行格式为 [顶点id, 顶点值, [[目标顶点id, 边权值], ...]]):
[0,0,[[1,0]]]
[1,0,[[2,100]]]
[2,100,[[3,200]]]
[3,300,[[4,300]]]
[4,600,[[5,400]]]
[5,1000,[[6,500]]]
[6,1500,[[7,600]]]
[7,2100,[[8,700]]]
[8,2800,[[9,800]]]
[9,3600,[[10,900]]]
[10,4500,[[11,1000]]]
[11,5500,[[12,1100]]]
[12,6600,[[13,1200]]]
[13,7800,[[14,1300]]]
[14,9100,[[0,1400]]]
答案:
package org.apache.giraph.examples;
import org.apache.giraph.graph.*;
import org.apache.giraph.lib.TextVertexInputFormat;
import org.apache.giraph.lib.TextVertexInputFormat.TextVertexReader;
import org.apache.giraph.lib.TextVertexOutputFormat;
import org.apache.giraph.lib.TextVertexOutputFormat.TextVertexWriter;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Logger;
import org.json.JSONArray;
import org.json.JSONException;
import java.io.IOException;
import java.util.Iterator;
public class SimpleShortestPathsVertex extends
Vertex<LongWritable, DoubleWritable, FloatWritable, DoubleWritable>
implements Tool {
/** Hadoop configuration supplied through {@link #setConf(Configuration)}. */
private Configuration conf;
/** Class logger. */
private static final Logger LOG =
    Logger.getLogger(SimpleShortestPathsVertex.class);
/** Configuration key under which the id of the source vertex is stored. */
public static final String SOURCE_ID = "SimpleShortestPathsVertex.sourceId";
/** Source vertex id used when {@link #SOURCE_ID} is absent from the configuration. */
public static final long SOURCE_ID_DEFAULT = 1;
/**
 * Tells whether this vertex is the configured starting point of the search.
 *
 * @return true when this vertex's id equals the value stored under
 *         {@link #SOURCE_ID} (or {@link #SOURCE_ID_DEFAULT} if unset)
 */
private boolean isSource() {
    long ownId = getVertexId().get();
    long configuredSourceId = getContext().getConfiguration().getLong(
        SOURCE_ID, SOURCE_ID_DEFAULT);
    return ownId == configuredSourceId;
}
/**
 * Runs one superstep of the shortest-path computation for this vertex.
 *
 * <p>In superstep 0 the vertex value is reset to {@code Double.MAX_VALUE}.
 * The candidate distance starts at 0 for the source vertex (otherwise
 * {@code Double.MAX_VALUE}) and is lowered to the smallest incoming message.
 * If the candidate beats the stored value, it is stored and
 * {@code candidate + edge value} is sent along every outgoing edge.
 * The vertex votes to halt at the end of every call.
 *
 * @param msgIterator distances proposed to this vertex by its neighbours
 */
@Override
public void compute(Iterator<DoubleWritable> msgIterator) {
    if (getSuperstep() == 0) {
        setVertexValue(new DoubleWritable(Double.MAX_VALUE));
    }

    double shortest;
    if (isSource()) {
        shortest = 0d;
    } else {
        shortest = Double.MAX_VALUE;
    }
    while (msgIterator.hasNext()) {
        double proposed = msgIterator.next().get();
        shortest = Math.min(shortest, proposed);
    }

    if (LOG.isDebugEnabled()) {
        LOG.debug("Vertex " + getVertexId() + " got minDist = " + shortest +
            " vertex value = " + getVertexValue());
    }

    boolean improved = shortest < getVertexValue().get();
    if (!improved) {
        // Nothing better was found: stay quiet until a new message arrives.
        voteToHalt();
        return;
    }

    setVertexValue(new DoubleWritable(shortest));
    for (LongWritable neighbourId : this) {
        FloatWritable weight = getEdgeValue(neighbourId);
        double distanceViaThis = shortest + weight.get();
        if (LOG.isDebugEnabled()) {
            LOG.debug("Vertex " + getVertexId() + " sent to " +
                neighbourId + " = " +
                distanceViaThis);
        }
        sendMsg(neighbourId, new DoubleWritable(distanceViaThis));
    }
    voteToHalt();
}
/**
 * Vertex input format: wraps the text record reader of each split in a
 * {@link SimpleShortestPathsVertexReader}.
 */
public static class SimpleShortestPathsVertexInputFormat extends
    TextVertexInputFormat<LongWritable, DoubleWritable, FloatWritable> {

    /**
     * Builds the vertex reader for one input split.
     *
     * @param split   input split to read
     * @param context task attempt context passed on to the text input format
     * @return a reader that turns the split's text lines into vertices
     * @throws IOException if the underlying record reader cannot be created
     */
    @Override
    public VertexReader<LongWritable, DoubleWritable, FloatWritable>
    createVertexReader(InputSplit split, TaskAttemptContext context)
        throws IOException {
        RecordReader<LongWritable, Text> lineReader =
            textInputFormat.createRecordReader(split, context);
        return new SimpleShortestPathsVertexReader(lineReader);
    }
}
/**
 * Vertex reader: parses one JSON line per vertex, of the form
 * {@code [id, value, [[targetId, edgeValue], ...]]}, and fills in the
 * vertex id, vertex value and outgoing edges.
 */
public static class SimpleShortestPathsVertexReader extends
    TextVertexReader<LongWritable, DoubleWritable, FloatWritable> {

    /**
     * @param lineRecordReader reader that supplies the raw text lines
     */
    public SimpleShortestPathsVertexReader(
        RecordReader<LongWritable, Text> lineRecordReader) {
        super(lineRecordReader);
    }

    /**
     * Reads the next line and populates {@code vertex} from it.
     *
     * @param vertex vertex to fill with id, value and edges
     * @return false when no more lines are available, true otherwise
     * @throws IllegalArgumentException if the line is not valid JSON of the
     *         expected shape
     */
    @Override
    public boolean next(
        MutableVertex<LongWritable, DoubleWritable, FloatWritable, ?> vertex)
        throws IOException, InterruptedException {
        boolean hasLine = getRecordReader().nextKeyValue();
        if (!hasLine) {
            return false;
        }
        Text line = getRecordReader().getCurrentValue();
        try {
            JSONArray fields = new JSONArray(line.toString());

            LongWritable id = new LongWritable(fields.getLong(0));
            vertex.setVertexId(id);

            DoubleWritable value = new DoubleWritable(fields.getDouble(1));
            vertex.setVertexValue(value);

            JSONArray edges = fields.getJSONArray(2);
            int edgeCount = edges.length();
            for (int idx = 0; idx < edgeCount; idx++) {
                JSONArray edge = edges.getJSONArray(idx);
                LongWritable target = new LongWritable(edge.getLong(0));
                FloatWritable weight =
                    new FloatWritable((float) edge.getDouble(1));
                vertex.addEdge(target, weight);
            }
        } catch (JSONException e) {
            throw new IllegalArgumentException(
                "next: Couldn't get vertex from line " + line, e);
        }
        return true;
    }
}
/**
 * Vertex output format: wraps the text record writer in a
 * {@link SimpleShortestPathsVertexWriter}.
 */
public static class SimpleShortestPathsVertexOutputFormat extends
    TextVertexOutputFormat<LongWritable, DoubleWritable,
        FloatWritable> {

    /**
     * Builds the vertex writer for one task attempt.
     *
     * @param context task attempt context used to obtain the record writer
     * @return a writer that emits each vertex as one text record
     */
    @Override
    public VertexWriter<LongWritable, DoubleWritable, FloatWritable>
    createVertexWriter(TaskAttemptContext context)
        throws IOException, InterruptedException {
        return new SimpleShortestPathsVertexWriter(
            textOutputFormat.getRecordWriter(context));
    }
}
/**
 * Vertex writer: serializes each computed vertex as a JSON array
 * {@code [id, value, [[targetId, edgeValue], ...]]} and writes it out as
 * one text record.
 */
public static class SimpleShortestPathsVertexWriter extends
    TextVertexWriter<LongWritable, DoubleWritable, FloatWritable> {

    /**
     * @param lineRecordWriter writer that receives the serialized vertices
     */
    public SimpleShortestPathsVertexWriter(
        RecordWriter<Text, Text> lineRecordWriter) {
        super(lineRecordWriter);
    }

    /**
     * Writes one vertex as a JSON line.
     *
     * @param vertex vertex whose id, value and edges are written
     * @throws IllegalArgumentException if the vertex cannot be encoded as
     *         JSON; the original {@link JSONException} is kept as the cause
     */
    @Override
    public void writeVertex(BasicVertex<LongWritable, DoubleWritable,
        FloatWritable, ?> vertex)
        throws IOException, InterruptedException {
        JSONArray jsonVertex = new JSONArray();
        try {
            jsonVertex.put(vertex.getVertexId().get());
            jsonVertex.put(vertex.getVertexValue().get());
            JSONArray jsonEdgeArray = new JSONArray();
            for (LongWritable targetVertexId : vertex) {
                JSONArray jsonEdge = new JSONArray();
                jsonEdge.put(targetVertexId.get());
                jsonEdge.put(vertex.getEdgeValue(targetVertexId).get());
                jsonEdgeArray.put(jsonEdge);
            }
            jsonVertex.put(jsonEdgeArray);
        } catch (JSONException e) {
            // Chain the cause so the JSON failure is not lost, matching the
            // error handling used when vertices are read.
            throw new IllegalArgumentException(
                "writeVertex: Couldn't write vertex " + vertex, e);
        }
        // The whole record is carried in the key; the value is left null.
        getRecordWriter().write(new Text(jsonVertex.toString()), null);
    }
}
/**
 * Returns the configuration previously stored by {@code setConf}.
 *
 * @return the stored configuration, or null if none has been set
 */
@Override
public Configuration getConf() {
    return this.conf;
}
/**
 * Stores the configuration that {@code getConf} will return.
 *
 * @param conf configuration to keep
 */
@Override
public void setConf(Configuration conf) {
    this.conf = conf;
}
/**
 * Configures and runs the shortest-paths Giraph job.
 *
 * <p>An earlier argument check ("Must have 4 arguments &lt;input path&gt;
 * &lt;output path&gt; &lt;source vertex id&gt; &lt;# of workers&gt;") has
 * been disabled: {@code argArray} is not read, and the input/output paths,
 * source vertex id and worker settings are hard-coded below.
 *
 * @param argArray command-line arguments (currently ignored)
 * @return 0 if the job reports success, -1 otherwise
 * @throws Exception if job setup or execution fails
 */
@Override
public int run(String[] argArray) throws Exception {
    // Set up the job: vertex class plus its input and output formats.
    GiraphJob job = new GiraphJob(getConf(), getClass().getName());
    job.setVertexClass(getClass());
    job.setVertexInputFormatClass(
        SimpleShortestPathsVertexInputFormat.class);
    job.setVertexOutputFormatClass(
        SimpleShortestPathsVertexOutputFormat.class);

    // Input and output locations. The "//" after the scheme is required:
    // without it "localhost:9000/..." is parsed as a relative path of an
    // absolute URI, which Path rejects.
    FileInputFormat.addInputPath(job,
        new Path("hdfs://localhost:9000/shortestPathsInputGraph"));
    FileOutputFormat.setOutputPath(job,
        new Path("hdfs://localhost:9000/SSSPOutputVertex"));

    // Start the search from vertex 1.
    job.getConfiguration().setLong(SimpleShortestPathsVertex.SOURCE_ID, 1);
    job.getConfiguration().setBoolean("giraph.SplitMasterWorker", false);
    // NOTE(review): presumably (min workers, max workers, min percent
    // responded) -- confirm against the GiraphJob API in use.
    job.setWorkerConfiguration(1, 1, 100.0f);

    return job.run(true) ? 0 : -1;
}
/**
 * Command-line entry point: runs this tool through {@link ToolRunner} and
 * exits the JVM with the returned status code.
 *
 * @param args command-line arguments handed to {@link ToolRunner}
 * @throws Exception if the tool fails
 */
public static void main(String[] args) throws Exception {
    int exitCode = ToolRunner.run(new SimpleShortestPathsVertex(), args);
    System.exit(exitCode);
}
}
输出(源顶点为1,每行第二项为该顶点到源顶点的最短距离):
[0,10500,[[1,0]]]
[1,0,[[2,100]]]
[2,100,[[3,200]]]
[3,300,[[4,300]]]
[4,600,[[5,400]]]
[5,1000,[[6,500]]]
[6,1500,[[7,600]]]
[7,2100,[[8,700]]]
[8,2800,[[9,800]]]
[9,3600,[[10,900]]]
[10,4500,[[11,1000]]]
[11,5500,[[12,1100]]]
[12,6600,[[13,1200]]]
[13,7800,[[14,1300]]]
[14,9100,[[0,1400]]]