from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
27 """Main entry point for SparkSQL functionality.
28
29 A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as
30 tables, execute SQL over tables, cache tables, and read parquet files.
31 """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
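        # Handle to the JVM helper PythonRDD.pythonToJavaMap, which converts an RDD of
        # pickled Python dictionaries into a JVM RDD of maps; inferSchema below relies on it.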
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM SparkSQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that to extract all the dictionaries. Nested
        collections are supported, which include array, dict, list, set, and
        tuple.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ... {"field1" : 3, "field2": "row3"}]
        True

        >>> from array import array
        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
        True

        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

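        # Ship the pickled dictionaries to the JVM as maps and let the Scala SQLContext
        # infer the schema from them.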
        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonFile(self, path):
        """Loads a text file storing one JSON object per line,
        returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> import tempfile, shutil
        >>> jsonFile = tempfile.mkdtemp()
        >>> shutil.rmtree(jsonFile)
        >>> ofn = open(jsonFile, 'w')
        >>> for json in jsonStrings:
        ...     print>>ofn, json
        >>> ofn.close()
        >>> srdd = sqlCtx.jsonFile(jsonFile)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1")
        >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}},
        ... {"f1": 2, "f2": "row2", "f3":{"field4":22}},
        ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}]
        True
        """
        jschema_rdd = self._ssql_ctx.jsonFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonRDD(self, rdd):
        """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> srdd = sqlCtx.jsonRDD(json)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1")
        >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}},
        ... {"f1": 2, "f2": "row2", "f3":{"field4":22}},
        ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}]
        True
        """
        def func(split, iterator):
            for x in iterator:
                if not isinstance(x, basestring):
                    x = unicode(x)
                yield x.encode("utf-8")
        keyed = PipelinedRDD(rdd, func)
        keyed._bypass_serializer = True
        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
        jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ... {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return self.hiveql(hqlQuery)


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metadata database is created with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


class Row(dict):
289 """A row in L{SchemaRDD}.
290
291 An extended L{dict} that takes a L{dict} in its constructor, and
292 exposes those items as fields.
293
294 >>> r = Row({"hello" : "world", "foo" : "bar"})
295 >>> r.hello
296 'world'
297 >>> r.foo
298 'bar'
299 """
300
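    # The constructor dict becomes the instance __dict__, which is what exposes each
    # key as an attribute on the Row.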
    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


class SchemaRDD(RDD):
    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.
    """

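    # Note that RDD.__init__ is not invoked here; the wrapped PythonRDD is only
    # materialized lazily through the _jrdd property below.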
    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

346 """Save the contents as a Parquet file, preserving the schema.
347
348 Files that are written out using this method can be read back in as
349 a SchemaRDD using the L{SQLContext.parquetFile} method.
350
351 >>> import tempfile, shutil
352 >>> parquetFile = tempfile.mkdtemp()
353 >>> shutil.rmtree(parquetFile)
354 >>> srdd = sqlCtx.inferSchema(rdd)
355 >>> srdd.saveAsParquetFile(parquetFile)
356 >>> srdd2 = sqlCtx.parquetFile(parquetFile)
357 >>> sorted(srdd2.collect()) == sorted(srdd.collect())
358 True
359 """
360 self._jschema_rdd.saveAsParquetFile(path)
361
363 """Registers this RDD as a temporary table using the given name.
364
365 The lifetime of this temporary table is tied to the L{SQLContext}
366 that was used to create this SchemaRDD.
367
368 >>> srdd = sqlCtx.inferSchema(rdd)
369 >>> srdd.registerAsTable("test")
370 >>> srdd2 = sqlCtx.sql("select * from test")
371 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
372 True
373 """
374 self._jschema_rdd.registerAsTable(name)
375
376 - def insertInto(self, tableName, overwrite = False):
377 """Inserts the contents of this SchemaRDD into the specified table.
378
379 Optionally overwriting any existing data.
380 """
381 self._jschema_rdd.insertInto(tableName, overwrite)
382
384 """Creates a new table with the contents of this SchemaRDD."""
385 self._jschema_rdd.saveAsTable(tableName)
386
388 """Returns the output schema in the tree format."""
389 return self._jschema_rdd.schemaString()
390
392 """Prints out the schema in the tree format."""
393 print self.schemaString()
394
396 """Return the number of elements in this RDD.
397
398 Unlike the base RDD implementation of count, this implementation
399 leverages the query optimizer to compute the count on the SchemaRDD,
400 which supports features such as filter pushdown.
401
402 >>> srdd = sqlCtx.inferSchema(rdd)
403 >>> srdd.count()
404 3L
405 >>> srdd.count() == srdd.map(lambda x: x).count()
406 True
407 """
408 return self._jschema_rdd.count()

    def _toPython(self):
        # Pull the rows back from the JVM as pickled Python dictionaries and wrap each
        # one in a Row so its fields are accessible as attributes.
        jrdd = self._jschema_rdd.javaToPython()
        return RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())).map(lambda d: Row(d))

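    # cache/persist/checkpoint and their relatives are overridden so that they act on the
    # underlying JVM SchemaRDD rather than on the PythonRDD built by _toPython.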
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def isCheckpointed(self):
        return self._jschema_rdd.isCheckpointed()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None

    def coalesce(self, numPartitions, shuffle=False):
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        if (other.__class__ is SchemaRDD):
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def subtract(self, other, numPartitions=None):
        if (other.__class__ is SchemaRDD):
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")

def _test():
    import doctest
    from array import array
    from pyspark.context import SparkContext
    globs = globals().copy()
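    # A small batch size makes the pickle serializer produce multiple batches even for
    # the tiny example datasets used in the doctests above.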
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
    jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
        '{"field1" : 2, "field2": "row2", "field3":{"field4":22}}',
        '{"field1" : 3, "field2": "row3", "field3":{"field4":33}}']
    globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
    globs['nestedRdd1'] = sc.parallelize([
        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
    globs['nestedRdd2'] = sc.parallelize([
        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()