GaussianMixtureModel
class pyspark.mllib.clustering.GaussianMixtureModel(java_model: py4j.java_gateway.JavaObject)
A clustering model fitted with the Gaussian mixture model method, for example the model returned by GaussianMixture.train.
New in version 1.3.0.
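For reference, the quantities exposed by the attributes fit the standard Gaussian mixture density below. The symbols are our own notation, not the library's: w_i stands for weights[i], μ_i and Σ_i for gaussians[i].mu and gaussians[i].sigma, d is the feature dimension, and i runs from 0 to k − 1. Under this textbook definition, predictSoft corresponds to the memberships γ_i(x) and predict to their argmax; the Spark implementation may differ in numerical details.

```latex
p(x) = \sum_{i=0}^{k-1} w_i \, \mathcal{N}(x \mid \mu_i, \Sigma_i),
\qquad w_i \ge 0, \quad \sum_{i=0}^{k-1} w_i = 1

\mathcal{N}(x \mid \mu, \Sigma)
  = \frac{1}{\sqrt{(2\pi)^{d} \det \Sigma}}
    \exp\!\left( -\tfrac{1}{2} (x - \mu)^{\top} \Sigma^{-1} (x - \mu) \right)

\gamma_i(x) = \frac{w_i \, \mathcal{N}(x \mid \mu_i, \Sigma_i)}
                   {\sum_{j=0}^{k-1} w_j \, \mathcal{N}(x \mid \mu_j, \Sigma_j)},
\qquad \hat{y}(x) = \arg\max_{i} \gamma_i(x)
```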
Examples
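>>> # sc is assumed to be an existing SparkContext
>>> from numpy import array
>>> from pyspark.mllib.clustering import GaussianMixture, GaussianMixtureModel
>>> from pyspark.mllib.linalg import Vectors, DenseMatrix
>>> from numpy.testing import assert_equal
>>> from shutil import rmtree
>>> import os, tempfile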
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
...     0.9,0.8,0.75,0.935,
...     -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
...     maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
False
>>> labels[4]==labels[5]
True
>>> model.predict([-0.1,-0.05])
0
>>> softPredicted = model.predictSoft([-0.1,-0.05])
>>> abs(softPredicted[0] - 1.0) < 0.03
True
>>> abs(softPredicted[1] - 0.0) < 0.03
True
>>> abs(softPredicted[2] - 0.0) < 0.03
True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = GaussianMixtureModel.load(sc, path)
>>> assert_equal(model.weights, sameModel.weights)
>>> mus, sigmas = list(
...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
>>> sameMus, sameSigmas = list(
...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
>>> mus == sameMus
True
>>> sigmas == sameSigmas
True
>>> try:
...     rmtree(path)
... except OSError:
...     pass
>>> data = array([-5.1971, -2.5359, -3.8220,
...     -5.2211, -5.0602, 4.7118,
...     6.8989, 3.4592, 4.6322,
...     5.7048, 4.6567, 5.5026,
...     4.5605, 5.2043, 6.2734])
>>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
...     maxIterations=150, seed=4)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]
True
>>> labels[2]==labels[3]==labels[4]
True
Methods
- call(name, *a): Call the method name of the wrapped Java model (java_model) with arguments a.
- load(sc, path): Load the GaussianMixtureModel from disk.
- predict(x): Find the cluster to which the point ‘x’ or each point in RDD ‘x’ has maximum membership in this model.
- predictSoft(x): Find the membership of point ‘x’ or each point in RDD ‘x’ to all mixture components.
- save(sc, path): Save this model to the given path.
Attributes
- gaussians: Array of MultivariateGaussian where gaussians[i] represents the multivariate Gaussian (normal) distribution for Gaussian i.
- k: Number of Gaussians in the mixture.
- weights: Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i; the weights sum to 1.
Methods Documentation
call(name: str, *a: Any) → Any
Call the method name of the wrapped Java model (java_model) with arguments a.
classmethod load(sc: pyspark.context.SparkContext, path: str) → pyspark.mllib.clustering.GaussianMixtureModel
Load the GaussianMixtureModel from disk.
New in version 1.5.0.
Parameters
- sc : SparkContext
- path : str
  Path to where the model is stored.
predict(x: Union[VectorLike, pyspark.rdd.RDD[VectorLike]]) → Union[numpy.int64, pyspark.rdd.RDD[int]]
Find the cluster to which the point ‘x’ or each point in RDD ‘x’ has maximum membership in this model.
New in version 1.3.0.
Parameters
- x : pyspark.mllib.linalg.Vector or pyspark.RDD
  A feature vector or an RDD of vectors representing data points.
Returns
- numpy.int64 or pyspark.RDD of int
  Predicted cluster label, or an RDD of predicted cluster labels if the input is an RDD.
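For a single vector, predict can be read as the argmax of the memberships returned by predictSoft, which fits the numpy.int64 return type; for an RDD, each element of predictSoft's output is an array.array of memberships. The sketch below applies that rule to a hypothetical membership vector without Spark. It illustrates the relationship between the two methods and is not the library's code.

```python
import array

import numpy as np

# Hypothetical memberships for one point, shaped like predictSoft's ndarray output.
soft = np.array([0.02, 0.95, 0.03])

# Hard assignment: index of the component with maximum membership.
label = soft.argmax()  # a NumPy integer scalar, not a Python int
print(label)

# RDD path: elements are array.array('d'), so use list-style indexing instead.
row = array.array("d", [0.02, 0.95, 0.03])
row_label = row.index(max(row))
print(row_label)
```

Both forms pick the first index on ties, so a point equally close to two components gets the lower label.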
predictSoft(x: Union[VectorLike, pyspark.rdd.RDD[VectorLike]]) → Union[numpy.ndarray, pyspark.rdd.RDD[array.array]]
Find the membership of point ‘x’ or each point in RDD ‘x’ to all mixture components.
New in version 1.3.0.
Parameters
- x : pyspark.mllib.linalg.Vector or pyspark.RDD
  A feature vector or an RDD of vectors representing data points.
Returns
- numpy.ndarray or pyspark.RDD
  The membership values of vector ‘x’ (or of each vector in RDD ‘x’) in all mixture components.
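To make the returned memberships concrete, the sketch below computes them in plain NumPy for a single point, assuming the textbook definition: each component's weighted density, normalized so the values sum to 1. The helper names (gaussian_pdf, soft_memberships) and the two-component parameters are hypothetical stand-ins for model.weights and the mu/sigma of model.gaussians. This is not PySpark's implementation, which delegates to the underlying Java model.

```python
import numpy as np


def gaussian_pdf(x, mu, sigma):
    """Density of the multivariate normal N(mu, sigma) at point x."""
    x = np.asarray(x, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    d = x.shape[0]
    diff = x - mu
    # Solve sigma @ y = diff rather than forming sigma's inverse explicitly.
    mahalanobis_sq = diff @ np.linalg.solve(sigma, diff)
    norm = np.sqrt((2.0 * np.pi) ** d * np.linalg.det(sigma))
    return np.exp(-0.5 * mahalanobis_sq) / norm


def soft_memberships(x, weights, mus, sigmas):
    """Hypothetical helper: normalized membership of x in each component.

    Does not guard against every weighted density underflowing to zero.
    """
    p = np.array([w * gaussian_pdf(x, mu, s)
                  for w, mu, s in zip(weights, mus, sigmas)])
    return p / p.sum()


# Hypothetical two-component mixture in 2-D.
weights = [0.5, 0.5]
mus = [[-0.1, -0.05], [0.9, 0.8]]
sigmas = [np.eye(2) * 0.01, np.eye(2) * 0.01]

memberships = soft_memberships([-0.1, -0.05], weights, mus, sigmas)
print(memberships)
```

A point sitting on the first mean gets a membership close to 1 for that component, which mirrors the softPredicted checks in the examples.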
save(sc: pyspark.context.SparkContext, path: str) → None
Save this model to the given path.
New in version 1.3.0.
Attributes Documentation
gaussians
Array of MultivariateGaussian where gaussians[i] represents the multivariate Gaussian (normal) distribution for Gaussian i.
New in version 1.4.0.
k
Number of Gaussians in the mixture.
New in version 1.4.0.
weights
Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i; the weights sum to 1.
New in version 1.4.0.
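The invariants stated for k and weights can be checked with a small helper. The function name check_mixture and its tolerance are hypothetical choices for this sketch; on a fitted model you would pass list(model.weights) and model.k. Because the weights are floating-point values, the sum is compared to 1 with a tolerance rather than with ==.

```python
import math


def check_mixture(weights, k, tol=1e-9):
    """Hypothetical helper: raise ValueError if the mixture weights are inconsistent."""
    if len(weights) != k:
        raise ValueError(f"expected {k} weights, got {len(weights)}")
    if any(w < 0 for w in weights):
        raise ValueError("mixture weights must be non-negative")
    total = sum(weights)
    # Compare with a tolerance: floating-point weights rarely sum to exactly 1.
    if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=tol):
        raise ValueError(f"weights sum to {total}, not 1")
    return True


print(check_mixture([0.2, 0.3, 0.5], k=3))
try:
    check_mixture([0.2, 0.3], k=3)
except ValueError as err:
    print("rejected:", err)
```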