Package pyspark :: Module context
[frames] | [no frames]

Source Code for Module pyspark.context

  1  # 
  2  # Licensed to the Apache Software Foundation (ASF) under one or more 
  3  # contributor license agreements.  See the NOTICE file distributed with 
  4  # this work for additional information regarding copyright ownership. 
  5  # The ASF licenses this file to You under the Apache License, Version 2.0 
  6  # (the "License"); you may not use this file except in compliance with 
  7  # the License.  You may obtain a copy of the License at 
  8  # 
  9  #    http://www.apache.org/licenses/LICENSE-2.0 
 10  # 
 11  # Unless required by applicable law or agreed to in writing, software 
 12  # distributed under the License is distributed on an "AS IS" BASIS, 
 13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 14  # See the License for the specific language governing permissions and 
 15  # limitations under the License. 
 16  # 
 17   
 18  import os 
 19  import shutil 
 20  import sys 
 21  from threading import Lock 
 22  from tempfile import NamedTemporaryFile 
 23   
 24  from pyspark import accumulators 
 25  from pyspark.accumulators import Accumulator 
 26  from pyspark.broadcast import Broadcast 
 27  from pyspark.conf import SparkConf 
 28  from pyspark.files import SparkFiles 
 29  from pyspark.java_gateway import launch_gateway 
 30  from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer 
 31  from pyspark.storagelevel import StorageLevel 
 32  from pyspark.rdd import RDD 
 33   
 34  from py4j.java_collections import ListConverter 
class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.

    Only one SparkContext may be active at a time: the active instance is
    recorded in the class attribute C{_active_spark_context}, which is set in
    L{_ensure_initialized} and cleared in L{stop}, both under C{_lock}.
    """

    _gateway = None
    _jvm = None
    _writeToFile = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
        """
        Create a new SparkContext. At least the master and app name should be set,
        either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes. The dictionary is copied, not modified.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.
        @raise ValueError: if another SparkContext is already active.
        @raise Exception: if no master URL or application name is configured.

        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        SparkContext._ensure_initialized(self)
        try:
            self._do_init(master, appName, sparkHome, pyFiles, environment,
                          batchSize, serializer, conf)
        except BaseException:
            # _ensure_initialized has already registered this instance as the
            # active context. Release whatever was created so far and
            # deregister it, otherwise a failed construction would make every
            # later SparkContext fail with "Cannot run multiple SparkContexts".
            self.stop()
            raise

    def _do_init(self, master, appName, sparkHome, pyFiles, environment,
                 batchSize, serializer, conf):
        """
        Configure this context and start its Java counterpart.

        Called only from L{__init__}, after this instance has been registered
        as the active context; see L{__init__} for the parameter meanings.
        """
        # Copy so that the executor variables read back from the conf below
        # are not written into the caller's dictionary.
        self.environment = dict(environment or {})
        self._conf = conf or SparkConf(_jvm=self._jvm)
        self._batchSize = batchSize  # -1 represents an unlimited batch size
        self._unbatched_serializer = serializer
        if batchSize == 1:
            self.serializer = self._unbatched_serializer
        else:
            self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                batchSize)

        # Set any parameters passed directly to us on the conf
        if master:
            self._conf.setMaster(master)
        if appName:
            self._conf.setAppName(appName)
        if sparkHome:
            self._conf.setSparkHome(sparkHome)
        if environment:
            for key, value in environment.items():
                self._conf.setExecutorEnv(key, value)

        # Check that we have at least the required parameters
        if not self._conf.contains("spark.master"):
            raise Exception("A master URL must be set in your configuration")
        if not self._conf.contains("spark.app.name"):
            raise Exception("An application name must be set in your configuration")

        # Read back our properties from the conf in case we loaded some of them from
        # the classpath or an external config file
        self.master = self._conf.get("spark.master")
        self.appName = self._conf.get("spark.app.name")
        self.sparkHome = self._conf.get("spark.home", None)
        for (k, v) in self._conf.getAll():
            if k.startswith("spark.executorEnv."):
                varName = k[len("spark.executorEnv."):]
                self.environment[varName] = v

        # Create the Java SparkContext through Py4J
        self._jsc = self._jvm.JavaSparkContext(self._conf._jconf)

        # Create a single Accumulator in Java that we'll send all our updates through;
        # they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

        # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    @classmethod
    def _ensure_initialized(cls, instance=None):
        """
        Launch the Py4J gateway once per process and, if C{instance} is given,
        register it as the active SparkContext.

        @raise ValueError: if a different SparkContext is already active.
        """
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

            if instance:
                active = SparkContext._active_spark_context
                if active and active is not instance:
                    raise ValueError("Cannot run multiple SparkContexts at once")
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a Java system property, such as spark.executor.memory. This must
        be invoked before instantiating SparkContext.
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    def __del__(self):
        self.stop()

    def stop(self):
        """
        Shut down the SparkContext.

        Safe to call repeatedly and on a partially constructed instance. Only
        resources that were actually created are released, and the class-wide
        active-context slot is cleared only if it refers to this instance.
        """
        if getattr(self, "_jsc", None):
            self._jsc.stop()
            self._jsc = None
        if getattr(self, "_accumulatorServer", None):
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            if SparkContext._active_spark_context is self:
                SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        @param c: any iterable; it is materialized into a list if it has no
               C{__len__}.
        @param numSlices: number of partitions (defaults to
               L{defaultParallelism}).

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands. As an alternative, serialized
        # objects are written to a file and loaded on the JVM side by
        # PythonRDD.readRDDFromFile.
        if not hasattr(c, "__len__"):
            c = list(c)  # Make it a list so we can compute its length
        # Make sure we distribute data evenly if it's smaller than self._batchSize
        batchSize = len(c) // numSlices
        if self._batchSize != -1:
            # -1 means "unlimited": keep the per-slice size instead of letting
            # min() collapse it to -1, which would disable batching entirely.
            batchSize = min(batchSize, self._batchSize)
        if batchSize > 1:
            serializer = BatchedSerializer(self._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._unbatched_serializer
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        try:
            serializer.dump_stream(c, tempFile)
        finally:
            # Close the handle even if serialization fails.
            tempFile.close()
        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self, serializer)

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.

        @param minSplits: minimum number of splits; defaults to
               C{min(defaultParallelism, 2)}.
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        return RDD(self._jsc.textFile(name, minSplits), self,
                   UTF8Deserializer())

    def _checkpointFile(self, name, input_deserializer):
        """Load a checkpointed RDD from C{name}, decoding it with C{input_deserializer}."""
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self, input_deserializer)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.

        C{rdds} must be a non-empty list; its first element determines the
        deserializer of the result.

        This supports unions() of RDDs with different serialized formats,
        although this forces them to be reserialized using the default
        serializer:

        >>> path = os.path.join(tempdir, "union-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello']
        >>> parallelized = sc.parallelize(["World!"])
        >>> sorted(sc.union([textFile, parallelized]).collect())
        [u'Hello', 'World!']
        """
        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
            rdds = [x._reserialize() for x in rdds]
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self,
                   rdds[0]._jrdd_deserializer)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a
        L{Broadcast<pyspark.broadcast.Broadcast>}
        object for reading it in distributed functions. The variable will be
        sent to each cluster only once.
        """
        pickleSer = PickleSerializer()
        pickled = pickleSer.dumps(value)
        jbroadcast = self._jsc.broadcast(bytearray(pickled))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers,
        floating-point and complex numbers if you do not provide one. For other
        types, a custom AccumulatorParam can be used.

        @raise Exception: if no C{accum_param} is given and C{value} is not an
               int, float or complex.
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * 100 for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        # TODO: remove added .py or .zip files from the PYTHONPATH?
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py, .zip or .egg dependency for all tasks to be executed on this
        SparkContext in the future.  The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)  # dirname may be directory or HDFS/S3 prefix

        if filename.endswith(('.zip', '.ZIP', '.egg')):
            # Archives must be on the path themselves; plain .py files are
            # importable from the SparkFiles root directory added in __init__.
            self._python_includes.append(filename)
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))  # for tests in local mode

    def setCheckpointDir(self, dirName):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be a HDFS path if running on a cluster.
        """
        self._jsc.sc().setCheckpointDir(dirName)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.

        @raise Exception: if C{storageLevel} is not a pyspark StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
                               storageLevel.deserialized, storageLevel.replication)
def _test():
    """
    Run this module's doctests against a local SparkContext.

    Exits the process with status -1 if any doctest fails.
    """
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs)
    finally:
        # Always shut the context down, even if doctest itself errors out.
        globs['sc'].stop()
    if failure_count:
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S`).
        sys.exit(-1)


if __name__ == "__main__":
    _test()