#!/usr/bin/env python3
"""
pgmon - a bridge between monitoring tools and PostgreSQL.

Modes (selected in main()):
    server (-S): load metric and cluster definitions, listen on a UNIX socket
                 and answer metric requests using a pool of worker threads.
    reload (-r): send SIGHUP to a running server (PID taken from the PID file).
    agent:       send a single request key to the server and print the reply
                 on stdout.
"""

import argparse
import json
import logging
import os
import queue
import signal
import socket
import sys
import threading

import psycopg2
import psycopg2.extras


#
# Errors
#

class InvalidConfigError(Exception):
    pass


class UnknownClusterError(Exception):
    pass


class DuplicateMetricVersionError(Exception):
    pass


class UnsupportedMetricVersionError(Exception):
    pass


class RequestTimeoutError(Exception):
    pass


class DBError(Exception):
    pass


class SQLInjectionError(Exception):
    """Raised by MetricVersion.get_sql when a substitution value looks unsafe."""
    pass


#
# Logging
#

logger = None

# Store the current log file since there is no public method to get the
# filename from a FileHandler object
current_log_file = None

# Handler objects for easy adding/removing/modifying
file_log_handler = None
stderr_log_handler = None


def init_logging(config):
    """
    Initialize (or re-initialize/modify) logging.

    Safe to call repeatedly (e.g. after a config reload): existing handlers are
    reused and reconfigured, replaced when the log file changes, or removed
    when their level is 'OFF'.

    Params:
        config: (Config) Provides stderr_log_level, file_log_level and log_file
    """
    global logger
    global current_log_file
    global file_log_handler
    global stderr_log_handler

    logger = logging.getLogger(__name__)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Set up, modify, or remove stderr logging
    if config.stderr_log_level != 'OFF':
        if stderr_log_handler is None:
            stderr_log_handler = logging.StreamHandler()
            logger.addHandler(stderr_log_handler)
        stderr_log_handler.setFormatter(formatter)
        stderr_log_handler.setLevel(logging.getLevelName(config.stderr_log_level))
    elif stderr_log_handler is not None:
        logger.removeHandler(stderr_log_handler)
        stderr_log_handler = None

    # Set up, modify, or remove file logging
    if config.file_log_level != 'OFF':
        # When switching files, keep the old handler until the new one is in
        # place so no records are lost in between
        old_file_handler = None
        if file_log_handler is not None and config.log_file != current_log_file:
            old_file_handler = file_log_handler
            file_log_handler = None

        if file_log_handler is None:
            file_log_handler = logging.FileHandler(config.log_file, encoding='utf-8')
            logger.addHandler(file_log_handler)
            current_log_file = config.log_file

        file_log_handler.setFormatter(formatter)
        file_log_handler.setLevel(logging.getLevelName(config.file_log_level))

        if old_file_handler is not None:
            logger.removeHandler(old_file_handler)
            # Release the old file descriptor
            old_file_handler.close()

        # Note where logs are being written. This goes to stderr because in
        # agent mode stdout carries the metric value to the monitoring tool.
        print("Logging to {} ({})".format(config.log_file, config.file_log_level),
              file=sys.stderr)
    elif file_log_handler is not None:
        logger.removeHandler(file_log_handler)
        file_log_handler.close()
        file_log_handler = None
        current_log_file = None

    # The logger itself must let through the most verbose level any handler wants
    levels = [
        h.level for h in (stderr_log_handler, file_log_handler)
        if h is not None and h.level != logging.NOTSET
    ]
    if levels:
        logger.setLevel(min(levels))
    else:
        # If we have no handlers, just bump the level to the max
        logger.setLevel(logging.CRITICAL)


#
# PID file handling
#

def write_pid_file(pid_file):
    """Write the current process ID to pid_file (no-op when pid_file is None)."""
    if pid_file is not None:
        with open(pid_file, 'w') as f:
            f.write("{}".format(os.getpid()))


def read_pid_file(pid_file):
    """
    Return the integer PID stored in pid_file.

    Exceptions:
        RuntimeError: if pid_file is None
        OSError: if the file cannot be read
        ValueError: if the file does not contain an integer
    """
    if pid_file is None:
        raise RuntimeError("No PID file specified")
    with open(pid_file, 'r') as f:
        return int(f.read().strip())


def remove_pid_file(pid_file):
    """Delete pid_file (no-op when pid_file is None)."""
    if pid_file is not None:
        os.unlink(pid_file)


#
# Global flags for signal handler
#

running = True
reload = False


#
# Signal handler
#

def signal_handler(sig, frame):
    """
    Function for handling signals

    INT  => Shut down
    TERM => Shut down
    QUIT => Shut down
    HUP  => Reload

    The handler only sets flags; the server's main loop acts on them.
    """
    global running
    global reload

    if sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        # Restore the default SIGINT handler so a second Ctrl-C forces an
        # immediate exit. Only done on shutdown: doing it on SIGHUP would make
        # the next Ctrl-C skip the graceful shutdown.
        signal.signal(signal.SIGINT, signal.default_int_handler)
        logger.info("Shutting down ...")
        running = False
    elif sig == signal.SIGHUP:
        logger.info("Reloading config ...")
        reload = True


#
# Classes
#

class Config:
    """
    Agent configuration

    Note: The config is initially loaded before logging is configured, so be
    mindful about logging anything in this class (diagnostics go to stderr via
    print instead).

    Params:
        args: (argparse.Namespace) Command line arguments
        read_metrics: (bool) Indicate if metrics should be parsed
        read_clusters: (bool) Indicate if cluster information should be parsed

    Exceptions:
        InvalidConfigError: Indicates an issue with the config file (including
                            failures to open/read/decode it)
    """

    # Simple "key = value" settings: config key (== attribute name) -> converter
    _SETTINGS = {
        'pid_file': str,
        'ipc_socket': str,
        'ipc_timeout': float,
        'request_timeout': float,
        'request_queue_size': int,
        'request_queue_timeout': float,
        'worker_count': int,
        'stderr_log_level': str.upper,
        'file_log_level': str.upper,
        'log_file': str,
    }

    def __init__(self, args, read_metrics = True, read_clusters = True):
        # Set defaults
        self.pid_file = None                 # PID file
        self.ipc_socket = 'pgmon.sock'       # IPC socket
        self.ipc_timeout = 10                # IPC communication timeout (s)
        self.request_timeout = 10            # Request processing timeout (s)
        self.request_queue_size = 100        # Max size of the request queue before it blocks
        self.request_queue_timeout = 2       # Max time to wait when queueing a request (s)
        self.worker_count = 4                # Number of worker threads
        self.stderr_log_level = 'INFO'       # Log level for stderr logging (or 'OFF')
        self.file_log_level = 'INFO'         # Log level for file logging (or 'OFF')
        self.log_file = 'pgmon.log'          # Log file
        self.metrics = {}                    # Metrics
        self.clusters = {}                   # Known clusters

        # Check if we have a config file
        self.config_file = args.config

        # Read config
        if self.config_file is not None:
            self.read_file(self.config_file, read_metrics, read_clusters)

        # Override anything that was specified on the command line
        if args.pidfile is not None:
            self.pid_file = args.pidfile
        if args.logfile is not None:
            self.log_file = args.logfile
        if args.socket is not None:
            self.ipc_socket = args.socket

    def read_file(self, config_file, read_metrics, read_clusters):
        """
        Read a config file, possibly skipping metrics and clusters to lighten
        the load on the agent.

        Each non-empty, non-comment line has the form "key = value". The
        "include" key reads another config file recursively.

        Params:
            config_file: (str) Path to the config file
            read_metrics: (bool) Indicate if metrics should be parsed
            read_clusters: (bool) Indicate if cluster information should be parsed

        Exceptions:
            InvalidConfigError: Indicates an issue with the config file
        """
        try:
            with open(config_file, 'r') as f:
                for line in f:
                    # Clean up any whitespace at either end
                    line = line.strip()

                    # Skip empty lines and comments
                    if line == '' or line.startswith('#'):
                        continue

                    if '=' not in line:
                        raise InvalidConfigError(
                            "{}: Invalid line (expected key = value): {}".format(config_file, line))

                    # Separate the line into a key-value pair and clean up
                    # extra white space that may have been around the '='
                    (key, value) = [x.strip() for x in line.split('=', 1)]

                    if key in self._SETTINGS:
                        try:
                            setattr(self, key, self._SETTINGS[key](value))
                        except ValueError:
                            raise InvalidConfigError(
                                "{}: Invalid value for {}: {}".format(config_file, key, value))
                    elif key == 'include':
                        print("Including file: {}".format(value), file=sys.stderr)
                        self.read_file(value, read_metrics, read_clusters)
                    elif key == 'cluster':
                        if read_clusters:
                            self.add_cluster(value)
                    elif key == 'metric':
                        if read_metrics:
                            print("Adding metric: {}".format(value), file=sys.stderr)
                            self.add_metric(value)
                    else:
                        raise InvalidConfigError("WARNING: Unknown config option: {}".format(key))
        except OSError as e:
            raise InvalidConfigError("Failed to open/read config file: {}".format(e))
        except UnicodeError as e:
            raise InvalidConfigError("Encoding error in config file: {}".format(e))

    def add_cluster(self, cluster_def):
        """
        Parse and add connection information about a cluster to the config.

        Each cluster line is of the format:
            <name>:[address]:[port]:[dbname]:[user]:[password]

        The name is a unique, arbitrary identifier to associate this cluster
        with a request from the monitoring agent.

        The address can be an IP address, host name, or path to the directory
        where a PostgreSQL socket exists. (Addresses containing ':' such as
        IPv6 literals are not supported by this format.)

        The dbname field is the default database to connect to for metrics
        which don't specify a database. This is also used to identify the
        PostgreSQL version.

        The password is the last field and may itself contain ':'.

        Default values replace empty components, except for the name field:
            address: /var/run/postgresql
            port: 5432
            dbname: postgres
            user: postgres
            password: None

        Params:
            cluster_def: (str) Cluster definition string

        Exceptions:
            InvalidConfigError: Thrown if the cluster entry is missing any
                                fields or contains invalid content
        """
        # Split up the fields; maxsplit keeps any ':' inside the password
        try:
            (name, address, port, dbname, user, password) = cluster_def.split(':', 5)
        except ValueError:
            raise InvalidConfigError("Missing fields in cluster definition: {}".format(cluster_def))

        # Make sure we have a name
        if name == '':
            raise InvalidConfigError("Cluster must have a name: {}".format(cluster_def))

        # Set defaults for anything that's blank
        if address == '':
            address = '/var/run/postgresql'

        if port == '':
            port = 5432
        else:
            try:
                port = int(port)
            except ValueError:
                raise InvalidConfigError("Invalid port number: {}".format(port))

        if dbname == '':
            dbname = 'postgres'

        if user == '':
            user = 'postgres'

        if password == '':
            password = None

        # Create and add the cluster object
        self.clusters[name] = Cluster(name, address, port, dbname, user, password)

    def add_metric(self, metric_def):
        """
        Parse and add a metric or metric version to the config.

        Each metric definition is of the form:
            <name>:[return type]:[PostgreSQL version]:<sql>

        The name is an identifier used to reference this metric from the
        monitoring server.

        The return type indicates how to format the results. It is taken from
        the first definition of a metric and must then be one of:
            value: A single value is returned (first column of first record)
            row: The first record is returned as a dict
            column: A list is returned containing the first column of all records
            set: All records are returned as a list of dicts

        The version field is the first version of PostgreSQL (integer format)
        for which this query is valid; empty means 0 (any version).

        The sql field is either the SQL to execute, or a string of the form:
            file:<path>
        where <path> is the path to a file containing the SQL. An empty sql
        field marks the metric as unsupported from that version onwards.

        In either case, the SQL can contain references to parameters that are
        substituted using python's str.format, so variable names should be
        enclosed in curly brackets like {foo}

        Params:
            metric_def: (str) Metric definition string

        Exceptions:
            InvalidConfigError: Thrown if the metric entry is missing any
                                fields or contains invalid content
        """
        # Split up the fields (the SQL itself may contain ':')
        try:
            (name, ret_type, version, sql) = metric_def.split(':', 3)
        except ValueError:
            raise InvalidConfigError("Missing fields in metric definition: {}".format(metric_def))

        # Make sure we have a name
        if name == '':
            raise InvalidConfigError("Missing name for metric: {}".format(metric_def))

        # An empty SQL query indicates a metric is not supported after a
        # particular version
        if sql == '':
            sql = None
        elif sql.startswith('file:'):
            # The SQL is in a separate file, read it in
            (_, path) = sql.split(':', 1)
            try:
                with open(path, 'r') as f:
                    sql = f.read()
            except OSError as e:
                raise InvalidConfigError("Failed to open/read SQL file: {}".format(e))
            except UnicodeError as e:
                raise InvalidConfigError("Encoding error in SQL file: {}".format(e))

        # If no starting version is given, set it to 0
        if version == '':
            version = 0
        else:
            try:
                version = int(version)
            except ValueError:
                raise InvalidConfigError("Invalid version for metric {}: {}".format(name, version))

        # Find or create the metric
        metric = self.metrics.get(name)
        if metric is None:
            # Reject unknown return types up front; the worker could not
            # format (and would never answer) requests for them
            if ret_type not in Metric.RETURN_TYPES:
                raise InvalidConfigError("Invalid return type for metric {}: {}".format(name, ret_type))
            metric = Metric(name, ret_type)
            self.metrics[name] = metric

        # Add what was given as a version of the metric
        try:
            metric.add_version(version, sql)
        except DuplicateMetricVersionError:
            raise InvalidConfigError("Duplicate version {} for metric: {}".format(version, name))

    def get_pg_version(self, cluster_name):
        """
        Return the version of PostgreSQL running on the specified cluster.

        The version is cached on the Cluster object after the first successful
        retrieval.

        Params:
            cluster_name: (str) The identifier for the cluster, as defined in
                          the config file

        Returns:
            The version using PostgreSQL's integer format (Mmmpp or MM00mm)

        Exceptions:
            UnknownClusterError: Thrown if the named cluster is not defined
            DBError: A database error occurred
        """
        # TODO: Move this out of the Config class. Possibly move the code
        #       to the DB class, while storing the value in the Cluster object.

        # Find the cluster
        try:
            cluster = self.clusters[cluster_name]
        except KeyError:
            raise UnknownClusterError(cluster_name)

        # Query the cluster if we don't already know the version
        # TODO: Expire this value at some point to pick up upgrades.
        if cluster.pg_version is None:
            db = DB(self, cluster_name)
            try:
                rows = db.query('SHOW server_version_num')
            finally:
                # Don't leak the connection if the query fails
                db.close()
            cluster.pg_version = int(rows[0]['server_version_num'])

        return cluster.pg_version


class Cluster:
    """
    Connection information for a PostgreSQL cluster

    Params:
        name: A unique, arbitrary identifier to associate this cluster with a
              request from the monitoring agent.
        address: An IP address, host name, or path to the directory where a
                 PostgreSQL socket exists.
        port: The database port number.
        dbname: Default database to connect to for metrics which don't specify
                a database. This is also used to identify the PostgreSQL version.
        user: The database user to connect as.
        password: The password to use when connecting to the database, or None
                  to use either no password or one stored in the executing
                  user's ~/.pgpass file.
    """
    def __init__(self, name, address, port, dbname, user, password):
        self.name = name
        self.address = address
        self.dbname = dbname
        self.port = port
        self.user = user
        self.password = password

        # Dynamically acquired PostgreSQL version
        self.pg_version = None


class DB:
    """
    Database access (one read-only, autocommit connection)

    Params:
        config: The agent config
        cluster_name: The name of the cluster to connect to
        dbname: A database to connect to, or None to use the default

    Exceptions:
        UnknownClusterError: Thrown if the named cluster is not defined
        DBError: Thrown for any database related error
    """
    def __init__(self, config, cluster_name, dbname=None):
        logger.debug("Creating connection to cluster: {}".format(cluster_name))

        # Defined up front so close() is always safe to call
        self.conn = None

        # Find the named cluster
        try:
            cluster = config.clusters[cluster_name]
        except KeyError:
            raise UnknownClusterError(cluster_name)

        # Use the default database if not given one
        if dbname is None:
            logger.debug("Using default database: {}".format(cluster.dbname))
            dbname = cluster.dbname

        # Connect to the database
        try:
            self.conn = psycopg2.connect(
                host = cluster.address,
                port = cluster.port,
                user = cluster.user,
                password = cluster.password,
                dbname = dbname)
        except Exception as e:
            raise DBError("Failed to connect to the database: {}".format(e))

        # Make the connection readonly and enable autocommit
        try:
            self.conn.set_session(readonly=True, autocommit=True)
        except Exception as e:
            # Don't leak the freshly opened connection
            self.close()
            raise DBError("Failed to set session parameters: {}".format(e))

    def query(self, sql, args=None):
        """
        Execute a query in the database.

        Params:
            sql: (str) The SQL statement to execute
            args: (list) List of positional arguments for the query. An empty
                  list is passed when omitted, so psycopg2 still performs
                  parameter substitution (literal '%' must be written '%%').

        Returns:
            A list of dict-like rows (psycopg2 RealDictCursor rows)

        Exceptions:
            DBError: Thrown for any database related error
        """
        if args is None:
            args = []

        logger.debug("Executuing query: {}".format(sql))
        try:
            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, args)
                return cur.fetchall()
        except Exception as e:
            raise DBError("Failed to execute query: {}".format(e))

    def close(self):
        """Close the connection if it is open; errors are logged, not raised."""
        logger.debug("Closing db connection")
        try:
            if self.conn is not None:
                self.conn.close()
        except psycopg2.Error as e:
            logger.warning("Caught an error when closing a connection: {}".format(e))
        self.conn = None


class Request:
    """
    A metric request

    Params:
        cluster_name: (str) The name of the cluster the request is for
        metric_name: (str) The name of the metric to obtain
        args: (dict) Dictionary of arguments for the metric (None => {})
    """
    def __init__(self, cluster_name, metric_name, args = None):
        self.cluster_name = cluster_name
        self.metric_name = metric_name
        # Avoid a shared mutable default between requests
        self.args = {} if args is None else args
        self.result = None

        # The lock is held until a result is set, so waiting for the result
        # means acquiring it
        self.complete = threading.Lock()
        self.complete.acquire()

    def set_result(self, result):
        """
        Set the result for the metric, and release the lock that allows the
        result to be returned to the client. Must be called at most once.
        """
        self.result = result
        self.complete.release()

    def get_result(self, timeout = -1):
        """
        Retrieve the result for the metric.

        This will wait for a result to be available. With the default timeout
        of -1 it waits indefinitely; otherwise a RequestTimeoutError is raised
        once the given number of seconds elapses.

        Params:
            timeout: (float) Number of seconds to wait before timing out
        """
        if self.complete.acquire(timeout = timeout):
            return self.result
        raise RequestTimeoutError()


class Metric:
    """
    A metric

    The return_type parameter controls how the results will be formatted when
    returned to the client. Possible values are (see RETURN_TYPES):
        value: Return a single value
        row: Return the first record as a dictionary
        column: Return a list comprised of the first column of the query results
        set: Return a list of dictionaries representing the queried records

    Params:
        name: (str) The name to associate with the metric
        ret_type: (str) The return type of the metric
    """

    # Return types the worker knows how to format
    RETURN_TYPES = ('value', 'row', 'column', 'set')

    def __init__(self, name, ret_type):
        self.name = name
        self.versions = {}
        self.type = ret_type

        # Cache mapping a requested PostgreSQL version to the MetricVersion
        # chosen for it (or None if no definition applies)
        self.cached = {}

    def add_version(self, pg_version, sql):
        """
        Add a versioned query for the metric

        Params:
            pg_version: (int) The first version number for which this query
                        applies, or 0 for any version
            sql: (str) The SQL to execute for the specified version, or None if
                 the metric is unsupported from this version onwards

        Exceptions:
            DuplicateMetricVersionError: if SQL already exists for pg_version
        """
        if pg_version in self.versions:
            raise DuplicateMetricVersionError(pg_version)

        self.versions[pg_version] = MetricVersion(sql)

        # A new version can change which definition applies to any version,
        # so previously cached lookups are no longer valid
        self.cached.clear()

    def get_version(self, pg_version):
        """
        Get an appropriate SQL query for the given PostgreSQL version

        Params:
            pg_version: (int) The version of PostgreSQL for which to find a query

        Returns:
            The MetricVersion with the highest starting version <= pg_version

        Exceptions:
            UnsupportedMetricVersionError: if no definition applies, or the
                applicable definition marks the metric as unsupported
        """
        # Since we're usually going to keep asking for the same version(s),
        # we cache the metric version the first time we need it.
        if pg_version not in self.cached:
            match = None
            # Search from the highest starting version down until we find one
            # that this PostgreSQL version satisfies
            for v in sorted(self.versions, reverse=True):
                if pg_version >= v:
                    match = self.versions[v]
                    break
            self.cached[pg_version] = match

        mv = self.cached[pg_version]

        # No query found, or the query is None (ie: the metric is no longer
        # supported for this version)
        if mv is None or mv.sql is None:
            raise UnsupportedMetricVersionError(pg_version)

        return mv


class MetricVersion:
    """
    A version of a metric

    The query is formatted using str.format with a dictionary of variables,
    so you can use things like {table} in the template, then pass table=foo
    when retrieving the SQL and '{table}' will be replaced with 'foo'.

    Note: Only minimal SQL injection checking is done on this. Basically we
          just throw an error if any substitution value contains an apostrophe
          or a semicolon.

    Params:
        sql: (str) The SQL query template (None = unsupported marker)
    """
    def __init__(self, sql):
        self.sql = sql

    def get_sql(self, args):
        """
        Return the SQL for this version after substituting the provided
        variables.

        Params:
            args: (dict) Dictionary of formatting substitutions to apply to the
                  query template

        Exceptions:
            SQLInjectionError: if any of the values to be substituted contain
                               an apostrophe (') or semicolon (;)
        """
        for v in args.values():
            if "'" in v or ';' in v:
                raise SQLInjectionError(v)

        return self.sql.format(**args)


class IPC:
    """
    IPC handler for communication between the agent and server components

    Messages are UTF-8 strings prefixed with their byte length as a 4 byte,
    big endian unsigned integer.

    Params:
        config: (Config) The agent configuration object
        mode: (str) Which side of the communication setup this is (agent|server)

    Exceptions:
        RuntimeError: if the IPC object's mode is invalid
    """
    def __init__(self, config, mode):
        # Validate the mode
        if mode not in ('agent', 'server'):
            raise RuntimeError("Invalid IPC mode: {}".format(mode))

        self.config = config
        self.mode = mode

        # Set up the connection
        try:
            self.initialize()
        except Exception as e:
            logger.debug("IPC Initialization error: {}".format(e))
            raise

    def initialize(self):
        """
        Initialize the socket

        Server mode binds and listens on config.ipc_socket (removing a stale
        socket file first); agent mode only creates the client socket.

        Exceptions:
            OSError: if the socket file already exists and could not be removed
            RuntimeError: if the IPC object's mode is invalid
        """
        logger.debug("Initializing IPC")
        if self.mode == 'server':
            # Try to clean up the socket if it exists
            try:
                logger.debug("Unlinking any former socket")
                os.unlink(self.config.ipc_socket)
            except OSError:
                logger.debug("Caught an exception unlinking socket")
                if os.path.exists(self.config.ipc_socket):
                    logger.debug("Socket stilll exists")
                    raise
                logger.debug("No socket to unlink")

            self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

            # Set a timeout for accepting connections to be able to catch signals
            self.socket.settimeout(1)

            logger.debug("Binding socket: {}".format(self.config.ipc_socket))
            self.socket.bind(self.config.ipc_socket)
            logger.debug("Listening on socket")
            self.socket.listen(1)
        elif self.mode == 'agent':
            self.conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

            # Set the timeout for the agent side of the connection
            self.conn.settimeout(self.config.ipc_timeout)
        else:
            raise RuntimeError("Invalid IPC mode: {}".format(self.mode))

        logger.debug("IPC initialization complete")

    def connect(self):
        """
        Establish and return a connection to the socket (agent mode)

        Returns:
            The connected socket object

        Exceptions:
            TimeoutError: if establishing the connection times out
        """
        self.conn.connect(self.config.ipc_socket)
        return self.conn

    def accept(self):
        """
        Wait for, accept, and return a connection to the socket (server mode)

        Returns:
            The accepted socket object, with config.ipc_timeout applied

        Exceptions:
            TimeoutError: if no connection is accepted before the timeout
        """
        (conn, _) = self.socket.accept()
        conn.settimeout(self.config.ipc_timeout)
        return conn

    def close(self):
        """Close this side's socket (listening socket for server, connection for agent)."""
        if self.mode == 'server':
            self.socket.close()
        else:
            self.conn.close()

    def send(self, conn, msg):
        """
        Send a message across the IPC channel

        The message is encoded in UTF-8, and prefixed with a 4 byte, big endian
        unsigned integer

        Params:
            conn: (Socket) the connection to use
            msg: (str) The message to send

        Exceptions:
            TimeoutError: if the connection times out before finishing sending
                          the message
            OverflowError: if the size of the message exceeds what fits in a
                           four byte, unsigned integer
            UnicodeError: if the message is not UTF-8 encodable
        """
        msg_bytes = msg.encode("utf-8")
        msg_len_bytes = len(msg_bytes).to_bytes(4, byteorder='big')
        conn.sendall(msg_len_bytes + msg_bytes)

    @staticmethod
    def _recv_into(conn, buffer):
        """Append one chunk from conn to buffer; raise ConnectionError on EOF."""
        chunk = conn.recv(1024)
        if not chunk:
            # recv() returns b'' once the peer has closed the connection;
            # without this check the read loops would spin forever
            raise ConnectionError("Connection closed before the full message was received")
        buffer += chunk

    def recv(self, conn):
        """
        Receive a message across the IPC channel

        The message is encoded in UTF-8, and prefixed with a 4 byte, big endian
        unsigned integer

        Params:
            conn: (Socket) the connection to use

        Returns:
            The received message

        Exceptions:
            TimeoutError: if the connection times out before finishing
                          receiving the whole message
            ConnectionError: if the peer closes the connection early
            UnicodeError: if the message is not UTF-8 decodable
        """
        buffer = bytearray()

        # Read at least the length prefix
        while len(buffer) < 4:
            self._recv_into(conn, buffer)

        msg_len = int.from_bytes(buffer[:4], byteorder='big')

        # Finish reading the message if we don't have all of it
        while len(buffer) < msg_len + 4:
            self._recv_into(conn, buffer)

        # Decode exactly the announced payload, ignoring any trailing bytes
        return bytes(buffer[4:4 + msg_len]).decode("utf-8")


class Agent:
    """
    The agent side of the connector

    This mode is intended to be called by the monitoring agent and
    communicates with the server side of the connector.

    The key is a comma separated string formatted as:
        <cluster>,<metric>[,<arg>=<value>[,...]]

    Results are printed to stdout.

    Params:
        args: (argparse.Namespace) Command line arguments
        key: (str) The key indicating the requested metric
    """
    @staticmethod
    def run(args, key):
        # Read the agent config (metrics/clusters are only needed by the server)
        config = Config(args, read_metrics = False, read_clusters = False)
        init_logging(config)

        ipc = IPC(config, 'agent')
        try:
            conn = ipc.connect()
            ipc.send(conn, key)
            res = ipc.recv(conn)
        except Exception as e:
            print("IPC error: {}".format(e))
            sys.exit(1)
        finally:
            ipc.close()

        print(res)


class Server:
    """
    The server side of the connector

    Params:
        args: (argparse.Namespace) Command line arguments

    Exceptions:
        InvalidConfigError: if the configuration cannot be loaded
        OSError: if the PID file or log file cannot be written
    """
    def __init__(self, args):
        # Keep the command line arguments so the config can be reloaded with
        # the same file and command line overrides
        self.args = args
        self.config_file = args.config

        self.config = Config(args)

        # Note: we record the PID file here so it can't be changed with reload
        self.pid_file = self.config.pid_file
        write_pid_file(self.pid_file)

        init_logging(self.config)

        # Set up the signal handler for every signal it documents
        for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT, signal.SIGHUP):
            signal.signal(sig, signal_handler)

        logger.debug("Creating request queue")
        self.req_queue = queue.Queue(self.config.request_queue_size)

        logger.debug("Spawning worker threads")
        self.workers = self.spawn_workers(self.config)

    def spawn_workers(self, config):
        """
        Spawn all worker threads to process requests

        Params:
            config: (Config) The agent config object

        Returns:
            The list of started Worker threads
        """
        logger.info("Spawning {} workers".format(config.worker_count))

        workers = []
        for i in range(config.worker_count):
            worker = Worker(config, self.req_queue)
            worker.start()
            logger.debug("Started thread #{} (tid={})".format(i, worker.native_id))
            workers.append(worker)

        return workers

    def retire_workers(self, workers):
        """
        Retire (terminate) all worker threads in the given list of threads

        Params:
            workers: (list) List of worker threads to part ways with
        """
        logger.info("Retiring {} workers".format(len(workers)))

        # Inform the workers that their services are no longer required
        for worker in workers:
            worker.active = False

        # Wait for the workers to turn in their badges
        for worker in workers:
            worker.join()

    def reload_config(self):
        """
        Reload the config and spawn new worker threads using it.

        Exceptions:
            InvalidConfigError: Indicates an issue with the config file
        """
        # Clear the reload flag first so a failing reload isn't retried forever
        global reload
        reload = False

        # Read new config (with the original command line overrides)
        new_config = Config(self.args)

        # Re-init logging in case settings changed
        init_logging(new_config)

        # Swap in new workers before retiring the old ones so requests keep
        # being served
        new_workers = self.spawn_workers(new_config)
        old_workers = self.workers
        self.workers = new_workers
        self.retire_workers(old_workers)

        # Adjust other settings (queue size, socket path, ...)
        # TODO

        self.config = new_config

    @staticmethod
    def _parse_key(key):
        """
        Split a request key "cluster,metric[,name=value,...]" into its parts.

        Returns:
            (cluster_name, metric_name, args_dict)

        Exceptions:
            IndexError: if the metric name is missing
            ValueError: if an argument is not of the form name=value
        """
        parts = key.split(',', 2)
        cluster_name = parts[0]
        metric_name = parts[1]

        args_dict = {}
        if len(parts) > 2:
            for arg in parts[2].split(','):
                if arg != '':
                    (k, v) = arg.split('=', 1)
                    args_dict[k] = v

        return (cluster_name, metric_name, args_dict)

    def _handle_connection(self, ipc, conn):
        """Read and queue one request, then hand the connection to a Responder."""
        # Get the request string (csv)
        try:
            key = ipc.recv(conn)
        except socket.timeout:
            logger.warning("IPC communication timeout receiving request")
            conn.close()
            return
        except (OSError, ValueError) as e:
            # Peer hung up early, socket error, or undecodable message
            logger.warning("IPC error receiving request: {}".format(e))
            conn.close()
            return

        logger.debug("Parsing request key: {}".format(key))
        try:
            (cluster_name, metric_name, args_dict) = self._parse_key(key)
        except (IndexError, ValueError) as e:
            logger.warning("Received invalid request '{}': {}".format(key, e))
            try:
                ipc.send(conn, "ERROR: Invalid key")
            except OSError as send_error:
                logger.warning("Failed to reply to agent: {}".format(send_error))
            conn.close()
            return

        req = Request(cluster_name, metric_name, args_dict)

        try:
            self.req_queue.put(req, timeout=self.config.request_queue_timeout)
        except queue.Full:
            # No free slot before the configured timeout; the error result is
            # still delivered to the client by the Responder below
            logger.warning("Failed to queue request, queue is full")
            req.set_result("ERROR: Enqueue timeout")

        # Spawn a thread to wait for the result and reply
        Responder(self.config, ipc, conn, req).start()

    def run(self):
        """
        Run the server's main loop until a shutdown signal is received, then
        retire the workers and clean up.
        """
        logger.info("Server starting")

        ipc = IPC(self.config, 'server')

        try:
            while True:
                # Wait for a request connection (times out so flags get checked)
                try:
                    logger.debug("Waiting for a connection")
                    conn = ipc.accept()
                except socket.timeout:
                    conn = None

                if not running:
                    if conn is not None:
                        conn.close()
                    break

                if reload:
                    try:
                        self.reload_config()
                    except Exception as e:
                        logger.error("Reload failed: {}".format(e))

                if conn is None:
                    continue

                self._handle_connection(ipc, conn)
        finally:
            ipc.close()
            self.retire_workers(self.workers)
            remove_pid_file(self.pid_file)
            logger.info("Good bye")
            logging.shutdown()


class Worker(threading.Thread):
    """
    Worker thread that processes requests (ie: queries the database)

    Params:
        config: (Config) The agent config object
        queue: (Queue) The request queue the worker should pull requests from
    """
    def __init__(self, config, queue):
        super(Worker, self).__init__()
        self.config = config
        self.queue = queue
        # Cleared by Server.retire_workers to ask the thread to exit
        self.active = True

    def run(self):
        """
        Main processing loop for a worker thread
        """
        while True:
            # Wait for a request (timeout lets us notice retirement)
            try:
                req = self.queue.get(timeout=1)
            except queue.Empty:
                req = None

            if not self.active:
                # If we got a request, try to put it back on the queue
                if req is not None:
                    try:
                        self.queue.put(req, timeout=1)
                        logger.info("Requeued request at worker exit")
                    except queue.Full:
                        logger.warning("Failed to requeue request at worker exit")
                        req.set_result("ERROR: Failed to requeue at Worker exit")
                logger.info("Worker exiting: tid={}".format(self.native_id))
                break

            if req is None:
                continue

            # Never let one bad request kill the worker or leave it unanswered
            try:
                result = self._process(req)
            except Exception as e:
                logger.error("Unexpected error processing request: {}".format(e))
                result = "ERROR: Internal error"
            req.set_result(result)

    def _process(self, req):
        """Run the request's query and return the formatted result string."""
        try:
            metric = self.config.metrics[req.metric_name]
        except KeyError:
            return "ERROR: Unknown metric: {}".format(req.metric_name)

        try:
            pg_version = self.config.get_pg_version(req.cluster_name)
        except Exception as e:
            logger.error("Failed to get Postgresql version for cluster {}: {}".format(req.cluster_name, e))
            return "Failed to retrieve database version for cluster: {}".format(req.cluster_name)

        try:
            mv = metric.get_version(pg_version)
        except UnsupportedMetricVersionError:
            return "Unsupported PosgreSQL version for metric"

        # Identify the database to query (None => cluster default)
        dbname = req.args.get('db')

        db = None
        try:
            db = DB(self.config, req.cluster_name, dbname)
            res = db.query(mv.get_sql(req.args))
        except Exception as e:
            logger.error("Database error: {}".format(e))
            return "Failed to query database"
        finally:
            # Make sure the database connection is closed (ie: if the query timed out)
            if db is not None:
                db.close()

        try:
            return self._format_result(metric.type, res)
        except (TypeError, ValueError, IndexError) as e:
            # e.g. values json can't serialize (Decimal, datetime) or a row
            # with no columns
            logger.error("Failed to format result for metric {}: {}".format(req.metric_name, e))
            return "ERROR: Failed to format result"

    @staticmethod
    def _format_result(ret_type, res):
        """Format query rows according to the metric's return type."""
        if len(res) == 0:
            return "Empty result set"
        if ret_type == 'value':
            return "{}".format(list(res[0].values())[0])
        if ret_type == 'row':
            return json.dumps(res[0])
        if ret_type == 'column':
            return json.dumps([list(r.values())[0] for r in res])
        if ret_type == 'set':
            return json.dumps(res)
        return "ERROR: Unknown metric type: {}".format(ret_type)


class Responder(threading.Thread):
    """
    Thread responsible for replying to requests

    Waits (up to config.request_timeout seconds) for the request's result,
    sends it to the client, and closes the connection.

    Params:
        config: (Config) The agent config object
        ipc: (IPC) The IPC object used for communication
        conn: (Socket) The connected socket to communicate with
        req: (Request) The request object to handle
    """
    def __init__(self, config, ipc, conn, req):
        super(Responder, self).__init__()
        self.config = config
        self.ipc = ipc
        self.conn = conn
        self.req = req

    def run(self):
        # The waiting happens here, in the thread, so the main loop is free to
        # accept other connections meanwhile
        try:
            try:
                result = self.req.get_result(timeout=self.config.request_timeout)
            except RequestTimeoutError:
                logger.warning("Timed out waiting for request result")
                result = "ERROR: Request timeout"

            try:
                self.ipc.send(self.conn, result)
            except Exception as e:
                logger.warning("Failed to reply to agent: {}".format(e))
        finally:
            self.conn.close()


def main():
    """Parse the command line and run in server, reload, or agent mode."""
    parser = argparse.ArgumentParser(
            prog='pgmon',
            description='A briidge between monitoring tools and PostgreSQL')

    # General options
    parser.add_argument('-c', '--config', default=None)
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-s', '--socket', default=None)
    parser.add_argument('-l', '--logfile', default=None)
    parser.add_argument('-p', '--pidfile', default=None)

    # Operational mode
    parser.add_argument('-S', '--server', action='store_true')
    parser.add_argument('-r', '--reload', action='store_true')

    # Agent options
    parser.add_argument('key', nargs='?')

    args = parser.parse_args()

    if args.server:
        try:
            server = Server(args)
        except Exception as e:
            sys.exit("Failed to start server: {}".format(e))

        try:
            server.run()
        except Exception as e:
            sys.exit("Caught an unexpected runtime error: {}".format(e))

    elif args.reload:
        # Signal a running server to reload its config
        try:
            config = Config(args, read_metrics = False)
        except Exception as e:
            sys.exit("Failed to read config file: {}".format(e))

        try:
            pid = read_pid_file(config.pid_file)
        except Exception as e:
            sys.exit("Failed to read PID file: {}".format(e))

        try:
            os.kill(pid, signal.SIGHUP)
        except OSError as e:
            sys.exit("Failed to signal server: {}".format(e))

    else:
        if args.key is None:
            parser.error("a metric key is required in agent mode")
        Agent.run(args, args.key)


##
# Tests
##

class TestConfig:
    pass


class TestDB:
    pass


class TestRequest:
    def test_request_creation(self):
        # Create result with no args
        req1 = Request('c1', 'foo', {})
        assert req1.cluster_name == 'c1'
        assert req1.metric_name == 'foo'
        assert len(req1.args) == 0
        assert req1.complete.locked()

        # Create result with args
        req2 = Request('c1', 'foo', {'arg1': 'value1', 'arg2': 'value2'})
        assert req2.metric_name == 'foo'
        assert len(req2.args) == 2
        assert req2.complete.locked()

        # Default args are not shared between requests
        req3 = Request('c1', 'foo')
        req4 = Request('c1', 'foo')
        req3.args['x'] = '1'
        assert req4.args == {}

    def test_request_lock(self):
        req1 = Request('c1', 'foo', {})
        assert req1.complete.locked()
        req1.set_result('blah')
        assert not req1.complete.locked()
        assert 'blah' == req1.get_result()


class TestMetric:
    def test_metric_creation(self):
        m1 = Metric('foo', 'value')
        assert m1.name == 'foo'
        assert len(m1.versions) == 0
        assert m1.type == 'value'
        assert m1.cached == {}

    def test_metric_add_version(self):
        m1 = Metric('foo', 'value')
        assert len(m1.versions) == 0
        m1.add_version(0, 'default')
        assert len(m1.versions) == 1
        m1.add_version(120003, 'v12.3')
        assert len(m1.versions) == 2

        assert m1.versions[0].sql == 'default'
        assert m1.versions[120003].sql == 'v12.3'

    def test_metric_get_version(self):
        import pytest

        m1 = Metric('foo', 'value')
        m1.add_version(100000, 'v10.0')
        m1.add_version(120000, 'v12.0')

        # Make sure cache is initially empty
        assert m1.cached == {}

        assert m1.get_version(110003).sql == 'v10.0'

        # Make sure cache is set
        assert 110003 in m1.cached

        # Make sure returned value changes with version
        assert m1.get_version(120000).sql == 'v12.0'
        assert m1.get_version(150005).sql == 'v12.0'

        # Make sure an error is thrown when no version matches
        with pytest.raises(UnsupportedMetricVersionError):
            m1.get_version(90603)

        # Adding a default version must invalidate the cached miss
        m1.add_version(0, 'default')
        assert m1.get_version(90603).sql == 'default'
        assert m1.get_version(110003).sql == 'v10.0'
        assert m1.get_version(120000).sql == 'v12.0'
        assert m1.get_version(150005).sql == 'v12.0'

        # A None query marks the metric unsupported from that version on
        m1.add_version(160000, None)
        assert m1.get_version(150005).sql == 'v12.0'
        with pytest.raises(UnsupportedMetricVersionError):
            m1.get_version(160002)


class TestMetricVersion:
    def test_metric_version_creation(self):
        mv1 = MetricVersion('test')
        assert mv1.sql == 'test'

    def test_metric_version_templating(self):
        mv1 = MetricVersion('foo')
        assert mv1.get_sql({}) == 'foo'

        mv2 = MetricVersion('foo {a1} {a3} {a2}')
        assert mv2.get_sql({'a1': 'bar', 'a2': 'blah blah blah', 'a3': 'baz'}) == 'foo bar baz blah blah blah'

    def test_metric_version_injection(self):
        import pytest

        mv1 = MetricVersion('foo {a}')
        with pytest.raises(SQLInjectionError):
            mv1.get_sql({'a': "x'; DROP TABLE t"})


class TestIPC:
    def test_send_recv_roundtrip(self):
        ipc = IPC.__new__(IPC)
        (a, b) = socket.socketpair()
        try:
            ipc.send(a, 'h\u00e9llo')
            assert ipc.recv(b) == 'h\u00e9llo'
        finally:
            a.close()
            b.close()

    def test_recv_eof(self):
        import pytest

        ipc = IPC.__new__(IPC)
        (a, b) = socket.socketpair()
        try:
            a.sendall(b'\x00\x00')
            a.close()
            with pytest.raises(ConnectionError):
                ipc.recv(b)
        finally:
            b.close()


class TestAgent:
    pass


class TestServer:
    pass


class TestWorker:
    pass


class TestResponder:
    pass


if __name__ == '__main__':
    main()