
Source Code for Module pyspark.sql

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """Main entry point for SparkSQL functionality.

    A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
    tables, execute SQL over tables, cache tables, and read parquet files.
    """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM SparkSQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that to extract all the dictionaries. Nested
        collections are supported, which include array, dict, list, set, and
        tuple.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ... {"field1" : 3, "field2": "row3"}]
        True

        >>> from array import array
        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
        True

        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)


    def jsonFile(self, path):
        """Loads a text file storing one JSON object per line,
        returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> import tempfile, shutil
        >>> jsonFile = tempfile.mkdtemp()
        >>> shutil.rmtree(jsonFile)
        >>> ofn = open(jsonFile, 'w')
        >>> for json in jsonStrings:
        ...   print>>ofn, json
        >>> ofn.close()
        >>> srdd = sqlCtx.jsonFile(jsonFile)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1")
        >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}},
        ... {"f1": 2, "f2": "row2", "f3":{"field4":22}},
        ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}]
        True
        """
        jschema_rdd = self._ssql_ctx.jsonFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonRDD(self, rdd):
        """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> srdd = sqlCtx.jsonRDD(json)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1")
        >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}},
        ... {"f1": 2, "f2": "row2", "f3":{"field4":22}},
        ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}]
        True
        """
        def func(split, iterator):
            for x in iterator:
                if not isinstance(x, basestring):
                    x = unicode(x)
                yield x.encode("utf-8")
        keyed = PipelinedRDD(rdd, func)
        keyed._bypass_serializer = True
        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
        jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ... {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)
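
# A minimal end-to-end sketch of the SQLContext API defined above, assuming a
# running SparkContext named `sc`; the table name and row contents are
# illustrative only:
#
#   sqlCtx = SQLContext(sc)
#   people = sc.parallelize([{"name": "Alice", "age": 1}, {"name": "Bob", "age": 2}])
#   srdd = sqlCtx.inferSchema(people)            # schema is inferred from the first dict
#   sqlCtx.registerRDDAsTable(srdd, "people")    # visible only to this SQLContext
#   adults = sqlCtx.sql("SELECT name FROM people WHERE age > 1")
#   sqlCtx.cacheTable("people")                  # keep the table in memory for reuse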


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return self.hiveql(hqlQuery)
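
# A brief usage sketch for HiveContext, assuming Spark was built with Hive
# support and that a table named `src` already exists in the metastore
# (both assumptions, not guaranteed by this module):
#
#   hiveCtx = HiveContext(sc)
#   rows = hiveCtx.hql("SELECT key, value FROM src")     # hql() delegates to hiveql()
#   rows.registerAsTable("src_snapshot")
#   hiveCtx.hql("SELECT COUNT(*) FROM src_snapshot").collect()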


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metadata database is created with its data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples
# are custom classes that must be generated per Schema.
class Row(dict):
    """A row in L{SchemaRDD}.

    An extended L{dict} that takes a L{dict} in its constructor, and
    exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)
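
# Because Row.__init__ makes the wrapped dict double as the instance __dict__,
# every key is reachable both as an attribute and as a plain dict entry.
# A small illustration (values are made up):
#
#   row = Row({"field1": 1, "field2": "row1"})
#   row.field1        # -> 1, attribute-style access
#   row["field2"]     # -> 'row1', ordinary dict access still works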


class SchemaRDD(RDD):
    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.
    """

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

    def saveAsParquetFile(self, path):
        """Save the contents as a Parquet file, preserving the schema.

        Files that are written out using this method can be read back in as
        a SchemaRDD using the L{SQLContext.parquetFile} method.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd2.collect()) == sorted(srdd.collect())
        True
        """
        self._jschema_rdd.saveAsParquetFile(path)

    def registerAsTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the L{SQLContext}
        that was used to create this SchemaRDD.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.registerAsTable("test")
        >>> srdd2 = sqlCtx.sql("select * from test")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        self._jschema_rdd.registerAsTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)
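
    # A short, illustrative sketch of writing a SchemaRDD back through the
    # catalog; the table name is made up, and these calls are typically used
    # with a Hive-enabled context:
    #
    #   srdd = hiveCtx.hql("SELECT key, value FROM src")
    #   srdd.saveAsTable("src_backup")                   # create a new table
    #   srdd.insertInto("src_backup", overwrite=False)   # append into an existing table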

    def schemaString(self):
        """Returns the output schema in tree format."""
        return self._jschema_rdd.schemaString()

    def printSchema(self):
        """Prints out the schema in tree format."""
        print self.schemaString()

    def count(self):
        """Return the number of elements in this RDD.

        Unlike the base RDD implementation of count, this implementation
        leverages the query optimizer to compute the count on the SchemaRDD,
        which supports features such as filter pushdown.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.count()
        3L
        >>> srdd.count() == srdd.map(lambda x: x).count()
        True
        """
        return self._jschema_rdd.count()

    def _toPython(self):
        # We have to import the Row class explicitly, so that the reference the
        # Pickler holds is pyspark.sql.Row instead of __main__.Row
        from pyspark.sql import Row
        jrdd = self._jschema_rdd.javaToPython()
        # TODO: This is inefficient, we should construct the Python Row object
        # in Java land in the javaToPython function. May require a custom
        # pickle serializer in Pyrolite
        return RDD(jrdd, self._sc, BatchedSerializer(
            PickleSerializer())).map(lambda d: Row(d))

    # We override the default cache/persist/checkpoint behavior as we want to cache the underlying
    # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def isCheckpointed(self):
        return self._jschema_rdd.isCheckpointed()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None
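
    # Persisting a SchemaRDD caches the underlying JVM SchemaRDD rather than a
    # PythonRDD (see the comment above cache()); a sketch, assuming the usual
    # pyspark StorageLevel import:
    #
    #   from pyspark.storagelevel import StorageLevel
    #   srdd = sqlCtx.inferSchema(rdd)
    #   srdd.persist(StorageLevel.MEMORY_ONLY)
    #   srdd.count()        # computed against the cached JVM SchemaRDD
    #   srdd.unpersist()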

    def coalesce(self, numPartitions, shuffle=False):
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def distinct(self):
        rdd = self._jschema_rdd.distinct()
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        if (other.__class__ is SchemaRDD):
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def repartition(self, numPartitions):
        rdd = self._jschema_rdd.repartition(numPartitions)
        return SchemaRDD(rdd, self.sql_ctx)

    def subtract(self, other, numPartitions=None):
        if (other.__class__ is SchemaRDD):
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")
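
# A short sketch of how SchemaRDD mixes the relational methods above with
# ordinary RDD operations; `srdd` is assumed to come from sqlCtx.inferSchema(rdd)
# as in the doctests:
#
#   srdd.registerAsTable("rows")
#   filtered = sqlCtx.sql("SELECT field1 FROM rows WHERE field1 > 1")
#   filtered.cache()                            # caches the JVM SchemaRDD, not a PythonRDD
#   values = filtered.map(lambda r: r.field1)   # the first RDD op triggers the lazy _jrdd conversion
#   values.collect()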


def _test():
    import doctest
    from array import array
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
                                   {"field1" : 2, "field2": "row2"},
                                   {"field1" : 3, "field2": "row3"}])
    jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
                   '{"field1" : 2, "field2": "row2", "field3":{"field4":22}}',
                   '{"field1" : 3, "field2": "row3", "field3":{"field4":33}}']
    globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
    globs['nestedRdd1'] = sc.parallelize([
        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
    globs['nestedRdd2'] = sc.parallelize([
        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()