Control # of mappers per RegionServer

This subclass of TableInputFormatBase adds two capabilities:

  • Run a single job against multiple scans (of the same HTable)
  • Control the number of mappers executing against each RegionServer

Background:

Let's say you have 20 RegionServers and a table with 200 regions. You want to do a full table scan with MapReduce, and execute your job on a cluster that can run 10 mappers at once.

Depending on how the regions get partitioned to mappers, or how fast certain mappers are, chances are that at some point all 10 map slots will be executing against the same RegionServer. This is not ideal from a load-balancing standpoint, and may result in failures if your RegionServer cannot handle the memory overhead of 10 large scans at once. Your long scans may also tie up all of that RegionServer's RPC handlers.

Using MultipleScanTableInputFormat, you can set the MultipleScanTableInputFormat.PARTITIONS_PER_REGION_SERVER configuration property to control how many mappers execute against a single RegionServer. The class groups all of the input splits by their location (RegionServer), and the RecordReader iterates through every aggregated split assigned to the mapper.

Note: The above example is contrived, but this helped us quite a bit in cases where we might otherwise have had 30-40 mappers hitting a single RegionServer for a single job (with multiple jobs running at once).
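For reference, a driver might look roughly like the sketch below. This is only an illustration, not part of the gist: the table name "my_table", the job name, the RowCountMapper, and the choice of 2 partitions per RegionServer are all assumptions.

package org.apache.hadoop.hbase.mapreduce;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.output.NullOutputFormat;

public class MultipleScanJobDriver {
  // Hypothetical mapper that just counts the rows it sees.
  public static class RowCountMapper extends TableMapper<NullWritable, NullWritable> {
    @Override
    protected void map(ImmutableBytesWritable key, Result value, Context context) {
      context.getCounter("scan", "rows").increment(1);
    }
  }

  public static void main(String[] args) throws Exception {
    Configuration conf = HBaseConfiguration.create();
    // Assumption for illustration: allow up to 2 mappers per RegionServer.
    conf.setInt(MultipleScanTableInputFormat.PARTITIONS_PER_REGION_SERVER, 2);

    Job job = new Job(conf, "full-table-scan"); // Job.getInstance(conf, ...) on newer Hadoop
    job.setJarByClass(MultipleScanJobDriver.class);

    // true => use MultipleScanTableInputFormat and group splits by RegionServer.
    HBaseMapReduceUtils.initHBaseInputMapReduceJob(job, "my_table",
        HBaseMapReduceUtils.createDefaultScan(), true);

    job.setMapperClass(RowCountMapper.class);
    job.setNumReduceTasks(0);
    job.setOutputFormatClass(NullOutputFormat.class);

    System.exit(job.waitForCompletion(true) ? 0 : 1);
  }
}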

package org.apache.hadoop.hbase.mapreduce;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.mapreduce.Job;

/**
 * Helpers for wiring a MapReduce job up to MultipleScanTableInputFormat.
 */
public class HBaseMapReduceUtils {
  private static final String CREATE_DEPENDENCY_JAR_KEY = "mapreduce.create.dependency.jar";

  public static void initHBaseInputMapReduceJob(Job job, String unversionedTableName) {
    initHBaseInputMapReduceJob(job, unversionedTableName, createDefaultScan());
  }

  public static void initHBaseInputMapReduceJob(Job job, String unversionedTableName, Scan scan) {
    initHBaseInputMapReduceJob(job, unversionedTableName, scan, false);
  }

  public static void initHBaseInputMapReduceJob(Job job, String unversionedTableName, Scan scan, boolean partitionByRegionServer) {
    initMultiScanHBaseInputMapReduceJob(job, unversionedTableName, scan,
        Collections.singletonList(scan.getStartRow()), Collections.singletonList(scan.getStopRow()),
        partitionByRegionServer);
  }

  public static void initMultiScanHBaseInputMapReduceJob(Job job, String unversionedTableName, Scan scan,
      List<byte[]> startRows, List<byte[]> stopRows, boolean partitionByRegionServer) {
    HBaseConfiguration.addHbaseResources(job.getConfiguration());
    job.setInputFormatClass(MultipleScanTableInputFormat.class);
    job.getConfiguration().set(TableInputFormat.INPUT_TABLE, unversionedTableName);
    job.getConfiguration().setBoolean(MultipleScanTableInputFormat.PARTITION_BY_REGION_SERVER, partitionByRegionServer);
    try {
      job.getConfiguration().set(TableInputFormat.SCAN, TableMapReduceUtil.convertScanToString(scan));
      job.getConfiguration().set(MultipleScanTableInputFormat.START_ROWS, MultipleScanTableInputFormat.convertRowsToString(startRows));
      job.getConfiguration().set(MultipleScanTableInputFormat.STOP_ROWS, MultipleScanTableInputFormat.convertRowsToString(stopRows));
      TableMapReduceUtil.addDependencyJars(job);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  // A reasonable default scan for full-table MapReduce jobs: large caching, no block cache pollution.
  public static Scan createDefaultScan() {
    final Scan scan = new Scan();
    scan.setCaching(5000);
    scan.setCacheBlocks(false);
    scan.setAttribute(Scan.SCAN_ATTRIBUTES_METRICS_ENABLE, Bytes.toBytes(Boolean.TRUE));
    return scan;
  }
}
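The multi-scan entry point takes parallel lists of start and stop rows; scan i covers the range from startRows.get(i) (inclusive) to stopRows.get(i) (exclusive), and the two lists must be the same length. A minimal sketch of calling it is below; the table name and row boundaries are made-up examples, not part of the gist.

// Illustration only: "my_table" and the row ranges are assumptions.
// Scans two key ranges of the same table in a single job, with splits grouped by RegionServer.
package org.apache.hadoop.hbase.mapreduce;

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.mapreduce.Job;

public class MultiRangeJobSetup {
  public static void configureInput(Job job) {
    List<byte[]> startRows = Arrays.asList(Bytes.toBytes("a"), Bytes.toBytes("m"));
    List<byte[]> stopRows = Arrays.asList(Bytes.toBytes("f"), Bytes.toBytes("r"));
    HBaseMapReduceUtils.initMultiScanHBaseInputMapReduceJob(job, "my_table",
        HBaseMapReduceUtils.createDefaultScan(), startRows, stopRows, true);
  }
}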
package org.apache.hadoop.hbase.mapreduce;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map.Entry;

import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.client.HTable;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
import org.apache.hadoop.hbase.util.Base64;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;

/**
 * Convert HBase tabular data into a format that is consumable by Map/Reduce.
 */
public class MultipleScanTableInputFormat extends TableInputFormatBase implements Configurable {
  public static final String START_ROWS = "hbase.mapreduce.startrows";
  public static final String STOP_ROWS = "hbase.mapreduce.stoprows";
  public static final String PARTITION_BY_REGION_SERVER = "hbase.mapreduce.partitionByRegionServer";
  public static final String PARTITIONS_PER_REGION_SERVER = "hbase.mapreduce.partitionByRegionServer.numPartitionsPerRegionServer";
  public static final int PARTITIONS_PER_REGION_SERVER_DEFAULT = 1;

  private Configuration conf;
  private List<byte[]> startRows;
  private List<byte[]> stopRows;
  private boolean partitionByRegionServer;
  private int numPartitionsPerRegionServer;

  /**
   * Returns the current configuration.
   *
   * @return The current configuration.
   * @see org.apache.hadoop.conf.Configurable#getConf()
   */
  @Override
  public Configuration getConf() {
    return conf;
  }

  @Override
  public RecordReader<ImmutableBytesWritable, Result> createRecordReader(InputSplit inputsplit, TaskAttemptContext context) throws IOException {
    LOG.info(String.format("Task attempt: %s, table split: %s", context.getTaskAttemptID(), inputsplit));
    // If partitioning by regionserver, we need a bunch of readers, one for each scan that will run against
    // each regionserver.
    if (partitionByRegionServer && inputsplit instanceof RegionServerInputSplit) {
      RegionServerInputSplit combined = (RegionServerInputSplit) inputsplit;
      LOG.info("Running scans against: " + combined.getRegionServer());
      List<TableRecordReader> readers = Lists.newArrayList();
      for (TableSplit split : combined.getSplits()) {
        readers.add(createRecordReaderForRegionServer(split, context));
      }
      return new RegionServerPartitionedRecordReader(readers);
    }
    return super.createRecordReader(inputsplit, context);
  }

  private TableRecordReader createRecordReaderForRegionServer(TableSplit split, TaskAttemptContext context) throws IOException {
    TableRecordReader trr = new TableRecordReader();
    Scan sc = new Scan(this.getScan());
    sc.setStartRow(split.getStartRow());
    sc.setStopRow(split.getEndRow());
    trr.setScan(sc);
    trr.setHTable(getHTable());
    return trr;
  }

  /**
   * Sets the configuration. This is used to set the details for the table to
   * be scanned.
   *
   * @param configuration The configuration to set.
   * @see org.apache.hadoop.conf.Configurable#setConf(
   *   org.apache.hadoop.conf.Configuration)
   */
  @Override
  public void setConf(Configuration configuration) {
    this.conf = configuration;
    this.partitionByRegionServer = conf.getBoolean(PARTITION_BY_REGION_SERVER, false);
    this.numPartitionsPerRegionServer = conf.getInt(PARTITIONS_PER_REGION_SERVER, PARTITIONS_PER_REGION_SERVER_DEFAULT);
    try {
      setHTable(new HTable(conf, conf.get(TableInputFormat.INPUT_TABLE)));
      setScan(TableMapReduceUtil.convertStringToScan(conf.get(TableInputFormat.SCAN)));
      this.startRows = convertStringToRows(conf.get(START_ROWS));
      this.stopRows = convertStringToRows(conf.get(STOP_ROWS));
    } catch (IOException e) {
      throw new RuntimeException(e);
    } catch (ClassNotFoundException e) {
      throw new RuntimeException(e);
    }
  }

  @Override
  public List<InputSplit> getSplits(JobContext context) throws IOException {
    List<InputSplit> source = getAggregatedSplits(context);
    if (!partitionByRegionServer) {
      return source;
    }
    // Group splits by regionserver
    Multimap<String, TableSplit> partitioned = ArrayListMultimap.<String, TableSplit>create();
    for (InputSplit split : source) {
      TableSplit cast = (TableSplit) split;
      String rs = cast.getRegionLocation();
      partitioned.put(rs, cast);
    }
    // Combine the splits for each regionserver into at most numPartitionsPerRegionServer splits
    List<InputSplit> result = new ArrayList<InputSplit>();
    for (Entry<String, Collection<TableSplit>> entry : partitioned.asMap().entrySet()) {
      for (List<TableSplit> partition : Iterables.partition(entry.getValue(), (int) Math.ceil((double) entry.getValue().size() / numPartitionsPerRegionServer))) {
        InputSplit split = new RegionServerInputSplit(entry.getKey(), partition);
        result.add(split);
      }
    }
    return result;
  }

  // Run the normal split calculation once per (startRow, stopRow) pair and concatenate the results.
  private List<InputSplit> getAggregatedSplits(JobContext context) throws IOException {
    final List<InputSplit> aggregatedSplits = new ArrayList<InputSplit>();
    final Scan scan = getScan();
    for (int i = 0; i < startRows.size(); i++) {
      scan.setStartRow(startRows.get(i));
      scan.setStopRow(stopRows.get(i));
      setScan(scan);
      aggregatedSplits.addAll(super.getSplits(context));
    }
    // set the state back to where it was
    scan.setStopRow(null);
    scan.setStartRow(null);
    setScan(scan);
    return aggregatedSplits;
  }

  @SuppressWarnings("unchecked")
  private static List<byte[]> convertStringToRows(String base64) throws IOException, ClassNotFoundException {
    ByteArrayInputStream bis = new ByteArrayInputStream(Base64.decode(base64));
    ObjectInputStream is = new ObjectInputStream(bis);
    return (List<byte[]>) is.readObject();
  }

  public static String convertRowsToString(List<byte[]> rows) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    ObjectOutputStream oos = new ObjectOutputStream(out);
    oos.writeObject(rows);
    oos.close(); // flush buffered object data before grabbing the bytes
    return Base64.encodeBytes(out.toByteArray());
  }
}
package org.apache.hadoop.hbase.mapreduce;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.List;

import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputSplit;

import com.google.common.base.Objects;
import com.google.common.base.Objects.ToStringHelper;
import com.google.common.collect.Lists;

/**
 * An InputSplit that wraps all of the TableSplits assigned to a single RegionServer.
 */
public class RegionServerInputSplit extends InputSplit implements Writable {
  private String regionServer;
  private List<TableSplit> splits;

  // Used for Writable deserialization
  public RegionServerInputSplit() {
    this.regionServer = null;
    this.splits = null;
  }

  public RegionServerInputSplit(String regionServer, List<TableSplit> splits) {
    this.regionServer = regionServer;
    this.splits = splits;
  }

  @Override
  public long getLength() throws IOException, InterruptedException {
    return splits.size();
  }

  @Override
  public String[] getLocations() throws IOException, InterruptedException {
    return new String[] { regionServer };
  }

  public List<TableSplit> getSplits() {
    return splits;
  }

  public String getRegionServer() {
    return regionServer;
  }

  @Override
  public void write(DataOutput out) throws IOException {
    out.writeUTF(regionServer);
    out.writeInt(splits.size());
    for (TableSplit split : splits) {
      split.write(out);
    }
  }

  @Override
  public void readFields(DataInput in) throws IOException {
    regionServer = in.readUTF();
    int size = in.readInt();
    splits = Lists.newArrayListWithCapacity(size);
    for (int i = 0; i < size; i++) {
      TableSplit split = new TableSplit();
      split.readFields(in);
      splits.add(i, split);
    }
  }

  @Override
  public String toString() {
    ToStringHelper helper = Objects.toStringHelper(this.getClass());
    helper.add("regionServer", regionServer);
    helper.add("tableSplits", splits);
    return helper.toString();
  }
}
package org.apache.hadoop.hbase.mapreduce;

import java.io.IOException;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.client.HTable;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;

public class RegionServerPartitionedRecordReader extends TableRecordReader {
  public static final String REGION_SERVER_PARTITION_INDEX = "hbase.tablemapper.regionserverpartition.index";
  private static final Log LOG = LogFactory.getLog(RegionServerPartitionedRecordReader.class);

  private int index;
  private List<TableRecordReader> readers;
  private TaskAttemptContext context;
  private boolean firstLogged = false;

  public RegionServerPartitionedRecordReader(List<TableRecordReader> readers) {
    this.readers = readers;
    this.context = null;
  }

  @Override
  public void initialize(InputSplit inputsplit, TaskAttemptContext context) throws IOException, InterruptedException {
    index = 0;
    LOG.info("Initializing TableRecordReader " + index + ": " + readers.get(index));
    readers.get(index).initialize(inputsplit, context);
    this.context = context;
    this.context.getConfiguration().setInt(REGION_SERVER_PARTITION_INDEX, index);
  }

  @Override
  public void restart(byte[] firstRow) throws IOException {
    LOG.info("Restarting TableRecordReader " + index);
    readers.get(index).restart(firstRow);
  }

  @Override
  public ImmutableBytesWritable getCurrentKey() throws IOException, InterruptedException {
    return readers.get(index).getCurrentKey();
  }

  @Override
  public Result getCurrentValue() throws IOException, InterruptedException {
    return readers.get(index).getCurrentValue();
  }

  @Override
  public boolean nextKeyValue() throws IOException, InterruptedException {
    if (!firstLogged) {
      LOG.info("First call to nextKeyValue, with reader " + index + ": " + readers.get(index));
      firstLogged = true;
    }
    boolean moved = false;
    while (index < readers.size()) {
      if (moved) {
        LOG.info("Moving to next TableRecordReader, " + index + " out of " + readers.size());
        readers.get(index).initialize(null, context);
        moved = false;
      }
      boolean result = readers.get(index).nextKeyValue();
      if (result) {
        return result;
      }
      readers.get(index).close();
      index++;
      moved = true;
      context.getConfiguration().setInt(REGION_SERVER_PARTITION_INDEX, index);
    }
    LOG.info("Finished all TableRecordReaders");
    return false;
  }

  @Override
  public float getProgress() {
    return (float) index / (float) readers.size();
  }

  @Override
  public void setHTable(HTable htable) {
    readers.get(index).setHTable(htable);
  }

  @Override
  public void setScan(Scan scan) {
    readers.get(index).setScan(scan);
  }

  @Override
  public void close() {
    // Ensure all readers are closed
    for (TableRecordReader reader : readers) {
      reader.close();
    }
  }
}
@vijay-jayaraman

Hi Guys,

Can anyone please explain how to run the above program (with a driver program)?

Thanks,
