#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import select
import struct
import socketserver as SocketServer
import threading
from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union

from pyspark.serializers import read_int, CPickleSerializer

if TYPE_CHECKING:
    from pyspark._typing import SupportsIAdd  # noqa: F401
    import socketserver.BaseRequestHandler  # type: ignore[import]

__all__ = ["Accumulator", "AccumulatorParam"]

T = TypeVar("T")
U = TypeVar("U", bound="SupportsIAdd")

pickleSer = CPickleSerializer()

# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry: Dict[int, "Accumulator"] = {}


def _deserialize_accumulator(
    aid: int, zero_value: T, accum_param: "AccumulatorParam[T]"
) -> "Accumulator[T]":
    from pyspark.accumulators import _accumulatorRegistry

    # If this certain accumulator was deserialized, don't overwrite it.
    if aid in _accumulatorRegistry:
        return _accumulatorRegistry[aid]
    else:
        accum = Accumulator(aid, zero_value, accum_param)
        accum._deserialized = True
        _accumulatorRegistry[aid] = accum
        return accum


class Accumulator(Generic[T]):
    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the `+=`
    operator, but only the driver program is allowed to access its value, using `value`.
    Updates from the workers get propagated automatically to the driver program.

    While :class:`SparkContext` supports accumulators for primitive data types like :class:`int`
    and :class:`float`, users can also define accumulators for custom types by providing a custom
    :py:class:`AccumulatorParam` object. Refer to its doctest for an example.

    Examples
    --------
    >>> a = sc.accumulator(1)
    >>> a.value
    1
    >>> a.value = 2
    >>> a.value
    2
    >>> a += 5
    >>> a.value
    7
    >>> sc.accumulator(1.0).value
    1.0
    >>> sc.accumulator(1j).value
    1j
    >>> rdd = sc.parallelize([1,2,3])
    >>> def f(x):
    ...     global a
    ...     a += x
    >>> rdd.foreach(f)
    >>> a.value
    13
    >>> b = sc.accumulator(0)
    >>> def g(x):
    ...     b.add(x)
    >>> rdd.foreach(g)
    >>> b.value
    6

    >>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...

    >>> def h(x):
    ...     global a
    ...     a.value = 7
    >>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...

    >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError: ...
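
    The errors in the last three examples illustrate the accumulator contract: worker
    tasks may only add to an accumulator (with ``+=`` or :meth:`add`); reading or
    assigning ``value`` is restricted to the driver program, and only supported value
    types (or types with a custom :class:`AccumulatorParam`) can be accumulated.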
""" def __init__(self, aid: int, value: T, accum_param: "AccumulatorParam[T]"): """Create a new Accumulator with a given initial value and AccumulatorParam object""" from pyspark.accumulators import _accumulatorRegistry self.aid = aid self.accum_param = accum_param self._value = value self._deserialized = False _accumulatorRegistry[aid] = self def __reduce__( self, ) -> Tuple[ Callable[[int, T, "AccumulatorParam[T]"], "Accumulator[T]"], Tuple[int, T, "AccumulatorParam[T]"], ]: """Custom serialization; saves the zero value from our AccumulatorParam""" param = self.accum_param return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) @property def value(self) -> T: """Get the accumulator's value; only usable in driver program""" if self._deserialized: raise RuntimeError("Accumulator.value cannot be accessed inside tasks") return self._value @value.setter def value(self, value: T) -> None: """Sets the accumulator's value; only usable in driver program""" if self._deserialized: raise RuntimeError("Accumulator.value cannot be accessed inside tasks") self._value = value [docs] def add(self, term: T) -> None: """Adds a term to this accumulator's value""" self._value = self.accum_param.addInPlace(self._value, term) def __iadd__(self, term: T) -> "Accumulator[T]": """The += operator; adds a term to this accumulator's value""" self.add(term) return self def __str__(self) -> str: return str(self._value) def __repr__(self) -> str: return "Accumulator<id=%i, value=%s>" % (self.aid, self._value) [docs]class AccumulatorParam(Generic[T]): """ Helper object that defines how to accumulate values of a given type. Examples -------- >>> from pyspark.accumulators import AccumulatorParam >>> class VectorAccumulatorParam(AccumulatorParam): ... def zero(self, value): ... return [0.0] * len(value) ... def addInPlace(self, val1, val2): ... for i in range(len(val1)): ... val1[i] += val2[i] ... return val1 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) >>> va.value [1.0, 2.0, 3.0] >>> def g(x): ... global va ... va += [x] * 3 >>> rdd = sc.parallelize([1,2,3]) >>> rdd.foreach(g) >>> va.value [7.0, 8.0, 9.0] """ [docs] def zero(self, value: T) -> T: """ Provide a "zero value" for the type, compatible in dimensions with the provided `value` (e.g., a zero vector) """ raise NotImplementedError [docs] def addInPlace(self, value1: T, value2: T) -> T: """ Add two values of the accumulator's data type, returning a new value; for efficiency, can also update `value1` in place and return it. """ raise NotImplementedError class AddingAccumulatorParam(AccumulatorParam[U]): """ An AccumulatorParam that uses the + operators to add values. Designed for simple types such as integers, floats, and lists. Requires the zero value for the underlying type as a parameter. """ def __init__(self, zero_value: U): self.zero_value = zero_value def zero(self, value: U) -> U: return self.zero_value def addInPlace(self, value1: U, value2: U) -> U: value1 += value2 # type: ignore[operator] return value1 # Singleton accumulator params for some standard types INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) # type: ignore[type-var] FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) # type: ignore[type-var] COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) # type: ignore[type-var] class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ This handler will keep polling updates from the same socket until the server is shutdown. 
""" def handle(self) -> None: from pyspark.accumulators import _accumulatorRegistry auth_token = self.server.auth_token # type: ignore[attr-defined] def poll(func: Callable[[], bool]) -> None: while not self.server.server_shutdown: # type: ignore[attr-defined] # Poll every 1 second for new data -- don't block in case of shutdown. r, _, _ = select.select([self.rfile], [], [], 1) if self.rfile in r: if func(): break def accum_updates() -> bool: num_updates = read_int(self.rfile) for _ in range(num_updates): (aid, update) = pickleSer._read_with_length(self.rfile) _accumulatorRegistry[aid] += update # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) return False def authenticate_and_accum_updates() -> bool: received_token: Union[bytes, str] = self.rfile.read(len(auth_token)) if isinstance(received_token, bytes): received_token = received_token.decode("utf-8") if received_token == auth_token: accum_updates() # we've authenticated, we can break out of the first loop now return True else: raise ValueError( "The value of the provided token to the AccumulatorServer is not correct." ) # first we keep polling till we've received the authentication token poll(authenticate_and_accum_updates) # now we've authenticated, don't need to check for the token anymore poll(accum_updates) class AccumulatorServer(SocketServer.TCPServer): def __init__( self, server_address: Tuple[str, int], RequestHandlerClass: Type["socketserver.BaseRequestHandler"], auth_token: str, ): SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) self.auth_token = auth_token """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. """ server_shutdown = False def shutdown(self) -> None: self.server_shutdown = True SocketServer.TCPServer.shutdown(self) self.server_close() def _start_update_server(auth_token: str) -> AccumulatorServer: """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() return server if __name__ == "__main__": import doctest from pyspark.context import SparkContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs["sc"] = SparkContext("local", "test") (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs["sc"].stop() if failure_count: sys.exit(-1)