1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 from pyspark import SparkContext
19 from pyspark.mllib._common import \
20 _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
21 _serialize_double_matrix, _deserialize_double_matrix, \
22 _serialize_double_vector, _deserialize_double_vector, \
23 _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
24 _serialize_tuple, RatingDeserializer
25 from pyspark.rdd import RDD
28 """A matrix factorisation model trained by regularized alternating
29 least-squares.
30
31 >>> r1 = (1, 1, 1.0)
32 >>> r2 = (1, 2, 2.0)
33 >>> r3 = (2, 1, 2.0)
34 >>> ratings = sc.parallelize([r1, r2, r3])
35 >>> model = ALS.trainImplicit(ratings, 1)
36 >>> model.predict(2,2) is not None
37 True
38 >>> testset = sc.parallelize([(1, 2), (1, 1)])
39 >>> model.predictAll(testset).count() == 2
40 True
41 """
42
44 self._context = sc
45 self._java_model = java_model
46
48 self._context._gateway.detach(self._java_model)
49
51 return self._java_model.predict(user, product)
52
54 usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
55 return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
56 self._context, RatingDeserializer())
57
@classmethod
def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
    """Train a matrix factorization model from an RDD of ratings.

    The ratings RDD is serialized with ``_serialize_rating`` and passed to
    ``PythonMLLibAPI.trainALSModel`` on the JVM. ``rank``, ``iterations``,
    ``lambda_`` and ``blocks`` are forwarded to that call unchanged; their
    exact meaning is defined by the JVM-side trainer.

    :param ratings: RDD of ratings; its ``context`` attribute supplies the
        SparkContext used for the JVM call. Presumably each element is a
        (user, product, rating) tuple -- confirm against
        ``_serialize_rating``.
    :param rank: forwarded to ``trainALSModel``.
    :param iterations: forwarded to ``trainALSModel`` (default 5).
    :param lambda_: forwarded to ``trainALSModel`` (default 0.01).
    :param blocks: forwarded to ``trainALSModel`` (default -1).
    :return: a ``MatrixFactorizationModel`` wrapping the returned Java model.
    """
    # The SparkContext is taken from the RDD itself, so callers do not
    # have to pass it separately.
    sc = ratings.context
    ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
    mod = sc._jvm.PythonMLLibAPI().trainALSModel(
        ratingBytes._jrdd, rank, iterations, lambda_, blocks)
    return MatrixFactorizationModel(sc, mod)
66
@classmethod
def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
    """Train a matrix factorization model via the implicit ALS trainer.

    The ratings RDD is serialized with ``_serialize_rating`` and passed to
    ``PythonMLLibAPI.trainImplicitALSModel`` on the JVM. ``rank``,
    ``iterations``, ``lambda_``, ``blocks`` and ``alpha`` are forwarded to
    that call unchanged; their exact meaning is defined by the JVM-side
    trainer.

    :param ratings: RDD of ratings; its ``context`` attribute supplies the
        SparkContext used for the JVM call. Presumably each element is a
        (user, product, rating) tuple -- confirm against
        ``_serialize_rating``.
    :param rank: forwarded to ``trainImplicitALSModel``.
    :param iterations: forwarded to ``trainImplicitALSModel`` (default 5).
    :param lambda_: forwarded to ``trainImplicitALSModel`` (default 0.01).
    :param blocks: forwarded to ``trainImplicitALSModel`` (default -1).
    :param alpha: forwarded to ``trainImplicitALSModel`` (default 0.01).
    :return: a ``MatrixFactorizationModel`` wrapping the returned Java model.
    """
    # The SparkContext is taken from the RDD itself, so callers do not
    # have to pass it separately.
    sc = ratings.context
    ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
    mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
        ratingBytes._jrdd, rank, iterations, lambda_, blocks, alpha)
    return MatrixFactorizationModel(sc, mod)
74
76 import doctest
77 globs = globals().copy()
78 globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
79 (failure_count, test_count) = doctest.testmod(globs=globs,
80 optionflags=doctest.ELLIPSIS)
81 globs['sc'].stop()
82 if failure_count:
83 exit(-1)
84
if __name__ == "__main__":
    # Only run the module's self-test when executed as a script, not on import.
    _test()
87