Clustering - spark.ml

In this section, we introduce the pipeline API for clustering in MLlib.

Table of Contents

- K-means
- Latent Dirichlet allocation (LDA)

K-means

k-means is one of the most commonly used clustering algorithms; it groups data points into a predefined number of clusters. The MLlib implementation includes a parallelized variant of the k-means++ method called k-means||.

KMeans is implemented as an Estimator and generates a KMeansModel as the base model.
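
The k-means|| initialization and the usual training knobs are exposed as Params on the KMeans Estimator; a minimal configuration sketch (the setter values here are illustrative):

import org.apache.spark.ml.clustering.KMeans

val kmeansEstimator = new KMeans()
  .setK(2)                   // number of clusters
  .setInitMode("k-means||")  // parallel k-means++ initialization (the default)
  .setInitSteps(5)           // number of steps in the k-means|| procedure
  .setMaxIter(20)            // maximum number of iterations
  .setSeed(1L)               // fix the random seed for reproducible runs
// Fitting this Estimator on a DataFrame of feature vectors yields a KMeansModel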

Input Columns

Param name  | Type(s) | Default    | Description
featuresCol | Vector  | "features" | Feature vector

Output Columns

Param name    | Type(s) | Default      | Description
predictionCol | Int     | "prediction" | Predicted cluster index
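
For instance, once a model has been fit (as in the example below), transform reads featuresCol and appends the integer predictionCol; a minimal sketch reusing the model and dataset names from that example:

val clustered = model.transform(dataset)  // appends the "prediction" column
clustered.select("features", "prediction").show()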

Example

Refer to the Scala API docs for more details.

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.DataFrame

// Creates a DataFrame
val dataset: DataFrame = sqlContext.createDataFrame(Seq(
  (1, Vectors.dense(0.0, 0.0, 0.0)),
  (2, Vectors.dense(0.1, 0.1, 0.1)),
  (3, Vectors.dense(0.2, 0.2, 0.2)),
  (4, Vectors.dense(9.0, 9.0, 9.0)),
  (5, Vectors.dense(9.1, 9.1, 9.1)),
  (6, Vectors.dense(9.2, 9.2, 9.2))
)).toDF("id", "features")

// Trains a k-means model
val kmeans = new KMeans()
  .setK(2)
  .setFeaturesCol("features")
  .setPredictionCol("prediction")
val model = kmeans.fit(dataset)

// Shows the result
println("Final Centers: ")
model.clusterCenters.foreach(println)
Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala" in the Spark repo.
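
Beyond printing the centers, the fitted model can also quantify clustering quality: KMeansModel.computeCost returns the within-cluster sum of squared distances. A short sketch, reusing model and dataset from the Scala example above:

val wssse = model.computeCost(dataset)
println(s"Within Set Sum of Squared Errors = $wssse")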

Refer to the Java API docs for more details.

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Loads data; ParsePoint (not shown here) parses each text line into a Row
// holding a dense feature Vector, analogous to ParseVector in the LDA example below
JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint());
StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
StructType schema = new StructType(fields);
DataFrame dataset = sqlContext.createDataFrame(points, schema);

// Trains a k-means model
KMeans kmeans = new KMeans()
  .setK(k);  // k: the number of clusters, assumed parsed from the arguments
KMeansModel model = kmeans.fit(dataset);

// Shows the result
Vector[] centers = model.clusterCenters();
System.out.println("Cluster Centers: ");
for (Vector center: centers) {
  System.out.println(center);
}
Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java" in the Spark repo.

Latent Dirichlet allocation (LDA)

LDA is implemented as an Estimator that supports both EMLDAOptimizer and OnlineLDAOptimizer, and generates an LDAModel as the base model. Expert users may cast an LDAModel generated by EMLDAOptimizer to a DistributedLDAModel if needed.
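
For example, the optimizer is selected with setOptimizer, and the cast looks as follows; a sketch, assuming a dataset of feature vectors like the ones built in the examples below:

import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

// With the "em" optimizer, the fitted model is distributed
val emModel = new LDA().setK(10).setOptimizer("em").fit(dataset)
val distributedModel = emModel.asInstanceOf[DistributedLDAModel]
val localModel = distributedModel.toLocal  // converts to a local model if desired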

Refer to the Scala API docs for more details.

import org.apache.spark.ml.clustering.LDA
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StructField, StructType}

// Loads data ("input" is the path to a text file of space-separated features)
val FEATURES_COL = "features"
val rowRDD = sc.textFile(input).filter(_.nonEmpty)
  .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
val dataset = sqlContext.createDataFrame(rowRDD, schema)

// Trains an LDA model
val lda = new LDA()
  .setK(10)
  .setMaxIter(10)
  .setFeaturesCol(FEATURES_COL)
val model = lda.fit(dataset)
val transformed = model.transform(dataset)

// Model diagnostics: log likelihood and log perplexity on the training data
val ll = model.logLikelihood(dataset)
val lp = model.logPerplexity(dataset)

// describeTopics
val topics = model.describeTopics(3)

// Shows the result
topics.show(false)
transformed.show(false)
Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala" in the Spark repo.
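
Note that describeTopics reports term indices rather than words. If the feature vectors were produced by CountVectorizer, the indices can be mapped back through its vocabulary; a sketch with a hypothetical cvModel (a CountVectorizerModel fitted upstream, not part of the example above):

import org.apache.spark.ml.feature.CountVectorizerModel

val vocab: Array[String] = cvModel.vocabulary  // cvModel is hypothetical here
topics.collect().foreach { row =>
  val terms = row.getAs[Seq[Int]]("termIndices").map(vocab(_))
  println(s"topic ${row.getAs[Int]("topic")}: ${terms.mkString(", ")}")
}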

Refer to the Java API docs for more details.

import java.util.regex.Pattern;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

private static class ParseVector implements Function<String, Row> {
  private static final Pattern separator = Pattern.compile(" ");

  @Override
  public Row call(String line) {
    String[] tok = separator.split(line);
    double[] point = new double[tok.length];
    for (int i = 0; i < tok.length; ++i) {
      point[i] = Double.parseDouble(tok[i]);
    }
    Vector[] points = {Vectors.dense(point)};
    return new GenericRow(points);
  }
}

public static void main(String[] args) {

  String inputFile = "data/mllib/sample_lda_data.txt";

  // Sets up the Spark context and the SQL context
  SparkConf conf = new SparkConf().setAppName("JavaLDAExample");
  JavaSparkContext jsc = new JavaSparkContext(conf);
  SQLContext sqlContext = new SQLContext(jsc);

  // Loads data
  JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParseVector());
  StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
  StructType schema = new StructType(fields);
  DataFrame dataset = sqlContext.createDataFrame(points, schema);

  // Trains an LDA model
  LDA lda = new LDA()
    .setK(10)
    .setMaxIter(10);
  LDAModel model = lda.fit(dataset);

  // Model diagnostics: log likelihood and log perplexity on the training data
  System.out.println(model.logLikelihood(dataset));
  System.out.println(model.logPerplexity(dataset));

  // Shows the result
  DataFrame topics = model.describeTopics(3);
  topics.show(false);
  model.transform(dataset).show(false);

  jsc.stop();
}
Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java" in the Spark repo.