# -*- coding: utf-8 -*-
"""Module implements base class :class:`GraphApiTask` for data extraction Celery tasks.

:class:`GraphApiTask` defines Redis and MS Graph API clients as attributes to be shared
by all tasks within a Celery worker.
"""

from datetime import datetime, timedelta
from functools import partial
from json import dumps as json_dumps
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Tuple

from celery import Task
from celery.utils.log import get_task_logger

import gevent
from gevent.pool import Group

from redis import BlockingConnectionPool, ConnectionPool, RedisError
from redis.client import Redis

from yaml import YAMLError, safe_load

from ms_graph_exporter.ms_graph.api import MsGraph
from ms_graph_exporter.ms_graph.response import MsGraphResponse

class GraphApiTask(Task):
    r"""Base class for MS Graph API data extraction Celery tasks.

    Allows to share instances of Redis connection pool
    :class:`BlockingConnectionPool <redis:redis.BlockingConnectionPool>`
    and :class:`~ms_graph_exporter.ms_graph.api.MsGraph` between all tasks
    executed within a single Celery worker.

    Attributes
    ----------
    _config : :obj:`~typing.Dict` [:obj:`str`, :obj:`~typing.Any`]
        Dictionary holding configuration parameters that are initialized
        by the worker on startup. All keys carry the ``graph_`` prefix.

    _logger : :obj:`~logging.Logger`
        Channel to be used for log output specific to the worker.

    _ms_graph : :obj:`~ms_graph_exporter.ms_graph.api.MsGraph`
        MS Graph API client instance to be used for queries.
        ``None`` until the client is first requested.

    _redis_conn_pool : :obj:`ConnectionPool <redis:redis.ConnectionPool>`
        Connection pool to be used by Redis clients for response data queuing.
        ``None`` until the pool is first requested.

    _redis_client : :obj:`Redis <redis:redis.client.Redis>`
        Redis client bound to ``_redis_conn_pool``.
        ``None`` until the client is first requested.

    _redis_methods_map : :obj:`~typing.Dict` [:obj:`str`, :obj:`str`]
        Redis command name used to push data for each supported queue type.

    """

    _config: Dict[str, Any] = {
        "graph_app_config": "instance/app_config.yaml",
        "graph_client_id": "",
        "graph_client_secret": "",
        "graph_greenlets_count": 10,
        "graph_page_size": 50,
        "graph_queue_backend": "redis",
        "graph_queue_type": "list",
        "graph_queue_key": "ms_graph_exporter",
        "graph_redis_url": "redis://localhost:6379?db=0",
        "graph_redis_pool_block": True,
        "graph_redis_pool_gevent_queue": True,
        "graph_redis_pool_max_connections": 15,
        "graph_redis_pool_timeout": 1,
        "graph_streams": 1,
        "graph_stream_frame": 60,
        "graph_tenant": "",
        "graph_timelag": 120,
    }

    _logger: Logger = get_task_logger(__name__)

    # Lazily created by the ``graph`` property.
    _ms_graph: Optional[MsGraph] = None

    # Lazily created by the ``redis_conn_pool`` property.
    _redis_conn_pool: Optional[ConnectionPool] = None

    # Lazily created by the ``redis_client`` property.
    _redis_client: Optional[Redis] = None

    _redis_methods_map: Dict[str, str] = {"list": "RPUSH", "channel": "PUBLISH"}
    """Map between Redis queue types and data pushing operation."""
[docs] @classmethod def config_update(cls, **options) -> None: """Update internal configuration. Accept arbitrary set of keyword arguments and update ``_config`` dict with all values matching keys defined there initially. Parameters ---------- **options Set of arbitrary keyword arguments used to update ``_config``. """ cls._logger.debug("[%s]: Update config from kwargs.", cls.__name__) config_update: Dict = {k: options[k] for k in options.keys() if "graph_" in k} cls._config.update(config_update) cls._logger.debug("[%s]: Final config: %s", cls.__name__, cls._config)
[docs] @classmethod def config_update_from_file(cls, config_file: str = None) -> None: """Update internal configuration from file. Read YAML ``config_file`` and update ``_config`` dict with all values matching keys defined there initially. Parameters ---------- config_file Path to YAML config file. If not defined, tries to load file mentioned in ``_config["graph_app_config"]`` """ if config_file is None: config_file = cls._config["graph_app_config"] if config_file != "": cls._logger.debug( "[%s]: Update config from file: %s", cls.__name__, config_file ) cls._config["graph_app_config"] = config_file config_yaml: Dict = {} with open(config_file, "r") as f: try: config_yaml = safe_load(f) except YAMLError: cls._logger.exception( "Exception loading config file '%s'", config_file ) if config_yaml: config_update = { "graph_{}".format(k): config_yaml.get(k, None) for k in config_yaml.keys() } cls._config.update(config_update) cls._logger.debug("[%s]: Final config: %s", cls.__name__, cls._config)
@property def graph(self) -> MsGraph: """Provide a shared instance of MS Graph API client.""" if self._ms_graph is None: self._ms_graph = MsGraph( client_id=self._config["graph_client_id"], client_secret=self._config["graph_client_secret"], tenant=self._config["graph_tenant"], ) return self._ms_graph
[docs] def _records_to_redis_naive(self, records: List[Any]) -> bool: """Push records to Redis. Push ``records`` into a Redis list or publish in a channel. Does not apply any optimizations. Yields control to Gevent Hub at each iteration. Parameters ---------- records List of records to store. Returns ------- bool Flag indicating success or failure of the operation. """ redis_client: Redis = self.redis_client queue_type: str = self._config["graph_queue_type"] queue_key: str = self._config["graph_queue_key"] try: redis_action = getattr( redis_client, self._redis_methods_map[queue_type].lower() ) for r in records: gevent.sleep() redis_action(queue_key, json_dumps(r)) except RedisError as e: self._logger.exception("Redis Exception: %s", str(e)) # noqa: G200 result = False else: result = True return result
[docs] def _records_to_redis_pipe(self, records: List[Any]) -> bool: """Push records to Redis with pipelining. Push ``records`` into a list or publish in a channel. Utilize pipelining to minimize connection overhead. Yields control to Gevent Hub at each iteration. Parameters ---------- records List of records to store. Returns ------- bool Flag indicating success or failure of the operation. """ redis_client: Redis = self.redis_client queue_type: str = self._config["graph_queue_type"] queue_key: str = self._config["graph_queue_key"] try: with redis_client.pipeline() as pipe: pipe.multi() redis_action = getattr( pipe, self._redis_methods_map[queue_type].lower() ) for r in records: gevent.sleep() redis_action(queue_key, json_dumps(r)) pipe.execute() except RedisError as e: self._logger.exception("Redis Exception: %s", str(e)) # noqa: G200 result = False else: result = True return result
@property def redis_client(self) -> Redis: """Provide an instance of Redis client.""" if self._redis_client is None: redis_client = Redis(connection_pool=self.redis_conn_pool) self._redis_client = redis_client self._logger.debug( "[%s]: Initialized Redis client: %s", self.__name__, self._redis_client ) return self._redis_client @property def redis_conn_pool(self) -> ConnectionPool: """Provide an instance of Redis connection pool.""" if self._redis_conn_pool is None: if self._config["graph_redis_pool_block"]: pool_class: Callable = BlockingConnectionPool else: pool_class = ConnectionPool if self._config["graph_redis_pool_gevent_queue"]: redis_conn_pool = pool_class().from_url( self._config["graph_redis_url"], decode_components=True, max_connections=self._config["graph_redis_pool_max_connections"], timeout=self._config["graph_redis_pool_timeout"], queue_class=gevent.queue.LifoQueue, ) else: redis_conn_pool = pool_class().from_url( self._config["graph_redis_url"], decode_components=True, max_connections=self._config["graph_redis_pool_max_connections"], timeout=self._config["graph_redis_pool_timeout"], ) self._redis_conn_pool = redis_conn_pool self._logger.debug( "[%s]: Initialized Redis connection pool: %s", self.__name__, self._redis_conn_pool, ) return self._redis_conn_pool
[docs] def redis_rm_queue(self) -> bool: """Delete Redis queue storing data records. Returns ------- bool Flag indicating success or failure of the operation. """ redis_client: Redis = self.redis_client queue_type: str = self._config["graph_queue_type"] queue_key: str = self._config["graph_queue_key"] try: redis_client.delete(queue_key) except RedisError as e: result: bool = False self._logger.exception( # noqa: G200 "Exception deleting Redis key: %s", str(e) ) else: result = True"Cleared %s key '%s'", queue_type.upper(), queue_key) return result
[docs] def task_fetch_slice( self, slice_start: datetime = None, slice_end: datetime = None ) -> MsGraphResponse: """Fetch a time slice of records. Request time-domain data slice ``[slice_start - slice_end]`` Parameters ---------- slice_start Beginning of the time-domain data slice. slice_end End of the time-domain data slice. """ result = self.graph.get_signins( timestamp_start=slice_start, timestamp_end=slice_end, page_size=self._config["graph_page_size"], ) return result
[docs] def task_get_time_slices( self, timestamp: datetime = None ) -> List[Tuple[datetime, datetime]]: """Calculate time slices from the point in time. Based on available configuration, generates a series of consecutive time slices (streams) going back in time from ``timestamp``. Parameters ---------- timestamp Point in time from which to calculate time slices. Microseconds are zeroed and ``timestamp`` is rounded up to a second. If not defined, uses current time. Note ---- Following ``_config`` options influence the result: * ``timelag`` - seconds to shift back from ``timestamp`` adjusting the whole time-frame for the series. * ``streams`` - Number of time slices to generate. * ``stream_frame`` - Size of each time slice in seconds. So, with the ``timestamp=...-07T22:02:53.123``, ``timelag=60``, ``streams=3`` and ``stream_frame=10`` following series is generated: .. code-block:: text 1. [...-07T22:01:44 - ...-07T22:01:53] 2. [...-07T22:01:34 - ...-07T22:01:43] 3. [...-07T22:01:24 - ...-07T22:01:33] """ total_streams: int = self._config["graph_streams"] t_now: datetime = ( timestamp.replace(microsecond=0) if timestamp is not None else datetime.utcnow().replace(microsecond=0) ) t_lag: timedelta = timedelta(seconds=self._config["graph_timelag"]) t_sec: timedelta = timedelta(seconds=1) t_delta: timedelta = timedelta(seconds=self._config["graph_stream_frame"]) frame_end: datetime = t_now - t_lag - t_sec frame_start: datetime = frame_end + t_sec - t_delta * total_streams "Split [%s - %s] into %s slices", frame_start.isoformat(), frame_end.isoformat(), total_streams, ) result: List[Tuple[datetime, datetime]] = [] for i in range(total_streams): slice_start: datetime = frame_end + t_sec - t_delta * (i + 1) slice_end: datetime = frame_end - t_delta * i result.append((slice_start, slice_end)) return result
[docs] def task_records_to_log(self, records: List[Any]) -> None: """Push records to log. Output ``records`` in the log. Parameters ---------- records List of records to queue. """ for r in records: "[%s]: %s - %s - %s/%s", r["createdDateTime"], r["userPrincipalName"], r["ipAddress"], r["location"]["countryOrRegion"], r["location"]["city"], )
[docs] def task_redis_push_single( self, records: List[Any], push_mode: str = "pipe" ) -> bool: """Push records to Redis in a single batch. Note ---- See the note for :meth:`task_redis_push_multi_spawn` for detailed explanation on parameter ``push_mode``. Parameters ---------- records List of records to queue. push_mode Specifies the data transfer strategy to use. Returns ------- bool Flag indicating success or failure of the operation """ result: bool = False push_method = getattr(self, "_records_to_redis_{}".format(push_mode)) result = push_method(records=records) return result
[docs] def task_redis_push_multi_spawn( self, records: List[Any], push_mode: str = "pipe" ) -> bool: """Push records to Redis in parallel. Split ``records`` list into smaller data sets to be pushed to Redis with multiple Greenlets. Note ---- Parameter ``push_mode`` defines which data transfer strategy to use by providing the suffix for specific ``_records_to_redis_*`` method to be called. Currently accepts ``naive`` or ``pipe``. Parameters ---------- records List of records to store in Redis. push_mode Specifies the data transfer strategy to use. Returns ------- bool Flag indicating success or failure of the operation """ # noqa: D202 def chunks(l: List, n: int): """Yield successive n-sized chunks from l.""" for i in range(0, len(l), n): yield l[i : i + n] # noqa: E203 result: bool = False batch_size: int = 0 records_len: int = len(records) greenlets_count: int = self._config["graph_greenlets_count"] if records_len > greenlets_count: batch_size = (((records_len << 1) // greenlets_count) + 1) >> 1 elif records_len > 0: batch_size = records_len if records_len: greenlets: Group = Group() # See redis_push_func_partial: Callable = partial( getattr(self, "_records_to_redis_{}".format(push_mode)) ) greenlet_list: List = [] for batch in chunks(records, batch_size): greenlet_list.append( greenlets.spawn(redis_push_func_partial, records=batch) ) greenlets.join(raise_error=True) greenlet_results = [g.value for g in greenlet_list] self._logger.debug( "%s Greenlets spawned: %s", len(greenlet_results), greenlet_results ) result = False if False in greenlet_results else True self._logger.debug("Summary result: %s", result) return result
[docs] def task_redis_push_multi_imap( self, records: List[Any], push_mode: str = "pipe" ) -> bool: """Push records to Redis in parallel. Split ``records`` list into smaller data sets to be pushed to Redis with multiple Greenlets. Parameter ``push_mode`` defines which data transfer strategy to use. Note ---- See the note for :meth:`task_redis_push_multi_spawn` for detailed explanation on parameter ``push_mode``. Parameters ---------- records List of records to store in Redis. push_mode Specifies the data transfer strategy to use. Returns ------- bool Flag indicating success or failure of the operation """ # noqa: D202 def chunks(l: List, n: int): """Yield successive n-sized chunks from l.""" for i in range(0, len(l), n): yield l[i : i + n] # noqa: E203 result: bool = False batch_size: int = 0 records_len: int = len(records) greenlets_count: int = self._config["graph_greenlets_count"] if records_len > greenlets_count: batch_size = (((records_len << 1) // greenlets_count) + 1) >> 1 elif records_len > 0: batch_size = records_len if batch_size: greenlets: Group = Group() redis_push_func_partial: Callable = partial( getattr(self, "_records_to_redis_{}".format(push_mode)) ) greenlet_results: List = [] for g in greenlets.imap_unordered( redis_push_func_partial, chunks(records, batch_size) ): greenlet_results.append(g) self._logger.debug( "%s Greenlets returned: %s", len(greenlet_results), greenlet_results ) result = False if False in greenlet_results else True self._logger.debug("Summary result: %s", result) return result