package org.gradoop.flink.algorithms.gelly.randomjump.functions;

import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.pregel.ComputeFunction;
import org.apache.flink.graph.pregel.MessageIterator;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.NullValue;
import org.gradoop.flink.algorithms.gelly.randomjump.KRandomJumpGellyVCI;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Compute function for {@link KRandomJumpGellyVCI}. Vertex values are of type
 * {@link VCIVertexValue}, with fields:
 * <ul>
 * <li>f0: {@code Boolean} set to {@code true} if the vertex was visited, to {@code false} otherwise
 * <li>f1: {@code List<Long>} containing all long indices from already visited outgoing edges
 * </ul>
 * A message of type {@code NullValue} from one vertex to another is a walk or a jump to
 * this other vertex and is therefore considered as one of the {@code k} walkers.
 */
public class VCIComputeFunction extends ComputeFunction<Long, VCIVertexValue, Long, NullValue> {

   * Name of the broadcast set containing the indices for the k starting vertices.
  public static final String START_INDICES_BROADCAST_SET = "startIndices";

   * Name of the broadcast set containing the graphs vertex indices.
  public static final String VERTEX_INDICES_BROADCAST_SET = "vertexIndices";

   * Name for the LongSumAggregator used for counting the visited vertices.
  public static final String VISITED_VERTICES_AGGREGATOR_NAME = "visitedVerticesAggregator";

   * Probability for jumping to a random vertex instead of walking to a random neighbor.
  private final double jumpProbability;

   * Number of vertices to visit at least.
  private final long verticesToVisit;

   * List with the indices for the k starting vertices.
  private List<Long> startIndices;

   * List containing the graphs vertex indices.
  private List<Long> vertexIndices;

   * The LongSumAggregator used for counting the visited vertices.
  private LongSumAggregator visitedVerticesAggregator;

   * Keeping track of the currently visited vertices at the beginning at each superstep.
  private long currentVisitedCount;

   * Creates an instance of VCIComputeFunction.
   * @param jumpProbability Probability for jumping to random vertex instead of walking to random
   *                        neighbor.
   * @param verticesToVisit Number of vertices to visit via walk or jump.
  public VCIComputeFunction(double jumpProbability, long verticesToVisit) {
    this.jumpProbability = jumpProbability;
    this.verticesToVisit = verticesToVisit;
    this.visitedVerticesAggregator = new LongSumAggregator();
    this.currentVisitedCount = 0L;

   * {@inheritDoc}
   * Reads the broadcast sets for the starting vertices and the graph vertices. Retrieves the
   * aggregator for visited vertices and accumulates the visited vertices from the previous
   * superstep.
  public void preSuperstep() {
    startIndices = (List<Long>) this.<Long>getBroadcastSet(START_INDICES_BROADCAST_SET);
    vertexIndices = (List<Long>) this.<Long>getBroadcastSet(VERTEX_INDICES_BROADCAST_SET);
    visitedVerticesAggregator = getIterationAggregator(VISITED_VERTICES_AGGREGATOR_NAME);
    LongValue previousAggregate = getPreviousIterationAggregate(VISITED_VERTICES_AGGREGATOR_NAME);
    if (previousAggregate != null) {
      currentVisitedCount += previousAggregate.getValue();

   * {@inheritDoc}
   * Initially starts the first walk or jump from the given start vertices. Afterwards starts a
   * walk or jump from a vertex, if it received messages from other vertices. Stops the
   * computation and therefor the iteration if the number of currently visited vertices exceeds
   * the number of vertices to visit.
   * @param vertex The vertex this computation in a superstep is running for.
   * @param messages Iterator over all incoming messages
  public void compute(Vertex<Long, VCIVertexValue> vertex, MessageIterator<NullValue> messages) {
    if (currentVisitedCount < verticesToVisit) {
      List<Edge<Long, Long>> edgesList = Lists.newArrayList(getEdges());
      Tuple2<VCIVertexValue, Boolean> valueWithHasChanged = Tuple2.of(vertex.getValue(), false);
      if (startIndices.contains(vertex.getId()) && !valueWithHasChanged.f0.isVisited()) {
        valueWithHasChanged = walkToRandomNeighbor(valueWithHasChanged, edgesList);
      } else if (messages.hasNext()) {
        for (NullValue msg : messages) {
          valueWithHasChanged = walkToRandomNeighbor(valueWithHasChanged, edgesList);
      if (valueWithHasChanged.f1) {

   * Performs a walk to a random neighbor by sending a message to a target from an unvisited
   * outgoing edge. Sets the vertex value as visited if necessary and updates the visited edge ids.
   * Sets a boolean flag, if the vertex value has changed.
   * Alternatively performs a jump to a random vertex with a probability given in
   * {@link #jumpProbability} or if there are no unvisited outgoing edges.
   * Returns the vertex value with the boolean flag as {@code Tuple2} eventually.
   * @param valueWithHasChanged {@code Tuple2} containing the vertex value and a boolean flag
   *                            determining if the value has changed
   * @param edgesList List of all outgoing edge for the vertex
   * @return {@code Tuple2} containing the vertex value and a boolean flag determining if the
   *         value has changed
  private Tuple2<VCIVertexValue, Boolean> walkToRandomNeighbor(
    Tuple2<VCIVertexValue, Boolean> valueWithHasChanged, List<Edge<Long, Long>> edgesList) {
    if (!valueWithHasChanged.f0.isVisited()) {
      valueWithHasChanged.f1 = true;
    if ((jumpProbability == 0d) || (jumpProbability < ThreadLocalRandom.current().nextDouble())) {
      List<Tuple2<Long, Long>> unvisitedNeighborWithEdgeId = new ArrayList<>();
      for (Edge<Long, Long> edge : edgesList) {
        if (!valueWithHasChanged.f0.getVisitedOutEdges().contains(edge.getValue())) {
          unvisitedNeighborWithEdgeId.add(Tuple2.of(edge.getTarget(), edge.getValue()));
      if (!unvisitedNeighborWithEdgeId.isEmpty()) {
        int randomIndex = ThreadLocalRandom.current().nextInt(unvisitedNeighborWithEdgeId.size());
        Long randomNeighborIndex = unvisitedNeighborWithEdgeId.get(randomIndex).f0;
        sendMessageTo(randomNeighborIndex, new NullValue());
        valueWithHasChanged.f1 = true;
      } else {
    } else {
    return valueWithHasChanged;

   * Jumps to a random vertex in the graph by sending a message to this vertex.
  private void jumpToRandomVertex() {
    int randomIndex = ThreadLocalRandom.current().nextInt(vertexIndices.size());
    Long randomVertexIndex = vertexIndices.get(randomIndex);
    sendMessageTo(randomVertexIndex, new NullValue());