import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeToFile = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None

    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
        """
        Create a new SparkContext. At least the master and app name should be set,
        either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH. These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object. Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.

        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        SparkContext._ensure_initialized(self)

        self.environment = environment or {}
        self._conf = conf or SparkConf(_jvm=self._jvm)
        self._batchSize = batchSize
        self._unbatched_serializer = serializer
        if batchSize == 1:
            self.serializer = self._unbatched_serializer
        else:
            self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                batchSize)

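        # Set any parameters passed directly to us on the conf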
        if master:
            self._conf.setMaster(master)
        if appName:
            self._conf.setAppName(appName)
        if sparkHome:
            self._conf.setSparkHome(sparkHome)
        if environment:
            for key, value in environment.iteritems():
                self._conf.setExecutorEnv(key, value)

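        # Check that we have at least the required parameters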
        if not self._conf.contains("spark.master"):
            raise Exception("A master URL must be set in your configuration")
        if not self._conf.contains("spark.app.name"):
            raise Exception("An application name must be set in your configuration")

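        # Read these back from the conf so that values set only on the conf
        # (rather than via the constructor arguments) are also reflected here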
        self.master = self._conf.get("spark.master")
        self.appName = self._conf.get("spark.app.name")
        self.sparkHome = self._conf.get("spark.home", None)
        for (k, v) in self._conf.getAll():
            if k.startswith("spark.executorEnv."):
                varName = k[len("spark.executorEnv."):]
                self.environment[varName] = v

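        # Create the Java SparkContext through Py4J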
        self._jsc = self._jvm.JavaSparkContext(self._conf._jconf)

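        # Create a single Accumulator in Java that we'll send all our updates through;
        # they will be passed back to us through a TCP server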
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

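        # Broadcast's __reduce__ method stores Broadcast instances in this list.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.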
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

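        # Deploy any code dependencies specified in the constructor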
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

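        # Create a temporary directory inside spark.local.dir: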
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    @classmethod
    def _ensure_initialized(cls, instance=None):
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

            if instance:
                if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
                    raise ValueError("Cannot run multiple SparkContexts at once")
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a Java system property, such as spark.executor.memory. This must
        be invoked before instantiating SparkContext.
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
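        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  As an alternative, serialized
        # objects are written to a file and loaded through readRDDFromFile.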
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
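        # Make sure we distribute data evenly if it's smaller than self.batchSize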
        if "__len__" not in dir(c):
            c = list(c)
        batchSize = min(len(c) // numSlices, self._batchSize)
        if batchSize > 1:
            serializer = BatchedSerializer(self._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._unbatched_serializer
        serializer.dump_stream(c, tempFile)
        tempFile.close()
        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self, serializer)

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.
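
        A small local example, in the style of the other doctests in this
        module (the file name below is only illustrative):

        >>> path = os.path.join(tempdir, "sample-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello world!")
        >>> sc.textFile(path).collect()
        [u'Hello world!']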
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        return RDD(self._jsc.textFile(name, minSplits), self,
                   UTF8Deserializer())

    def _checkpointFile(self, name, input_deserializer):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self, input_deserializer)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.

        This supports unions() of RDDs with different serialized formats,
        although this forces them to be reserialized using the default
        serializer:

        >>> path = os.path.join(tempdir, "union-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello']
        >>> parallelized = sc.parallelize(["World!"])
        >>> sorted(sc.union([textFile, parallelized]).collect())
        [u'Hello', 'World!']
        """
        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
            rdds = [x._reserialize() for x in rdds]
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self,
                   rdds[0]._jrdd_deserializer)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a
        L{Broadcast<pyspark.broadcast.Broadcast>}
        object for reading it in distributed functions. The variable will be
        sent to each node only once.
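
        A minimal driver-side example (the broadcast value is arbitrary):

        >>> b = sc.broadcast([1, 2, 3, 4, 5])
        >>> b.value
        [1, 2, 3, 4, 5]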
        """
        pickleSer = PickleSerializer()
        pickled = pickleSer.dumps(value)
        jbroadcast = self._jsc.broadcast(bytearray(pickled))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
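
        A small driver-side sketch using the default integer AccumulatorParam:

        >>> a = sc.accumulator(1)
        >>> a += 5
        >>> a.value
        6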
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * fileVal for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future. The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)

        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
            self._python_includes.append(filename)
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))

    def setCheckpointDir(self, dirName):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.
        """
        self._jsc.sc().setCheckpointDir(dirName)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
                               storageLevel.deserialized, storageLevel.replication)

def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()