#!/usr/bin/env python3
import argparse
import socket
import sys
import threading
import psycopg2
import psycopg2.extras
import queue
import os
import signal
import json
import logging

# pytest is only used by the test classes at the bottom of this file;
# don't make it a hard runtime dependency of the agent/server.
try:
    import pytest
except ImportError:
    pytest = None

#
# Errors
#

class InvalidConfigError(Exception):
    """Raised for malformed config lines, unknown keys, or bad metric definitions"""
    pass

class DuplicateMetricVersionError(Exception):
    """Raised when a metric gets two definitions for the same minimum version"""
    pass

class UnsupportedMetricVersionError(Exception):
    """Raised when no version of a metric matches the server's PostgreSQL version"""
    pass

class RequestTimeoutError(Exception):
    """Raised when waiting for a request's result exceeds the timeout"""
    pass

#
# Logging
#

logger = None

# Store the current log file since there is no public method to get the filename
# from a FileHandler object
current_log_file = None

# Handler objects for easy adding/removing/modifying
file_log_handler = None
stderr_log_handler = None

def init_logging(config):
    """
    Initialize (or re-initialize/modify) logging.

    Safe to call repeatedly (e.g. on config reload): handlers are created
    once, reconfigured on later calls, and removed when a level is 'OFF'.
    """
    global logger
    global file_log_handler
    global stderr_log_handler

    # Get the logger object
    logger = logging.getLogger(__name__)

    # Create a common formatter
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Set up or modify stderr logging
    if config.stderr_log_level != 'OFF':
        # Create and add the handler if it doesn't exist
        if stderr_log_handler is None:
            stderr_log_handler = logging.StreamHandler()
            logger.addHandler(stderr_log_handler)

        # Set the formatter
        stderr_log_handler.setFormatter(formatter)

        # Set the log level
        level = logging.getLevelName(config.stderr_log_level)
        stderr_log_handler.setLevel(level)
    else:
        if stderr_log_handler is not None:
            logger.removeHandler(stderr_log_handler)
            stderr_log_handler = None

    # Set up or modify file logging
    if config.file_log_level != 'OFF':
        # Create and add the handler if it doesn't exist
        # NOTE(review): if log_file changes on reload, the existing handler
        # keeps writing to the old file -- confirm whether that is intended.
        if file_log_handler is None:
            file_log_handler = logging.FileHandler(config.log_file, encoding='utf-8')
            logger.addHandler(file_log_handler)

        # Set the formatter
        file_log_handler.setFormatter(formatter)

        # Set the log level
        level = logging.getLevelName(config.file_log_level)
        file_log_handler.setLevel(level)

        # Note where logs are being written
        print("Logging to {}".format(config.log_file))
    else:
        if file_log_handler is not None:
            logger.removeHandler(file_log_handler)
            file_log_handler = None

#
# PID file handling
#

def write_pid_file(pid_file):
    """Write the current process ID to pid_file"""
    with open(pid_file, 'w') as f:
        f.write("{}".format(os.getpid()))

def read_pid_file(pid_file):
    """Read and return the process ID stored in pid_file"""
    with open(pid_file, 'r') as f:
        return int(f.read().strip())

def remove_pid_file(pid_file):
    """Delete pid_file"""
    os.unlink(pid_file)

#
# Global flags for signal handler
#

running = True
reload = False

#
# Signal handler
#

def signal_handler(sig, frame):
    """
    Translate signals into the global running/reload flags polled by the
    server's main loop.
    """
    # Restore the original handler so a second SIGINT interrupts immediately
    signal.signal(signal.SIGINT, signal.default_int_handler)

    # Signal everything to shut down
    if sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
        logger.info("Shutting down ...")
        global running
        running = False
    elif sig == signal.SIGHUP:
        logger.info("Reloading config ...")
        global reload
        reload = True

#
# Classes
#

class Config:
    """
    Agent configuration
    """
    def __init__(self, config_file, read_metrics=True):
        """
        Set defaults, then read the given config file.

        config_file: path of the config file to read
        read_metrics: when False, 'metric' lines are skipped (the agent and
                      the reload signal path don't need them)
        """
        # Set defaults
        self.pid_file = '/tmp/pgmon.pid'     # PID file
        self.ipc_socket = '/tmp/pgmon.sock'  # IPC socket
        self.ipc_timeout = 10                # IPC communication timeout (s)
        self.request_timeout = 10            # Request processing timeout (s)
        self.request_queue_size = 100        # Max size of the request queue before it blocks
        self.request_queue_timeout = 2       # Max time to wait when queueing a request (s)
        self.worker_count = 4                # Number of worker threads
        self.stderr_log_level = 'INFO'       # Log level for stderr logging (or 'OFF')
        self.file_log_level = 'INFO'         # Log level for file logging (or 'OFF')
        self.log_file = 'pgmon.log'          # Log file
        self.host = 'localhost'              # DB host
        self.port = 5432                     # DB port
        self.user = 'postgres'               # DB user
        self.password = None                 # DB password
        self.dbname = 'postgres'             # DB name
        self.metrics = {}                    # Metrics

        # Dynamic values
        self.pg_version = None               # PostgreSQL version

        # Read config
        self.read_file(config_file, read_metrics)

    def read_file(self, config_file, read_metrics=True):
        """
        Parse a key=value config file, recursing into 'include' files.

        Raises InvalidConfigError for lines without '=' or unknown keys.
        """
        with open(config_file, 'r') as f:
            for line in f:
                line = line.strip()

                # Skip comments and blank lines
                if line.startswith('#'):
                    continue
                elif line == '':
                    continue

                # BUG FIX: the original checked 'value is None' after the
                # split, which str.split can never produce, and built the
                # error message without calling .format(); a line missing
                # '=' crashed with a raw ValueError instead.
                if '=' not in line:
                    raise InvalidConfigError("{}: {}".format(config_file, line))
                (key, value) = line.split('=', 1)

                if key == 'include':
                    self.read_file(value, read_metrics)
                elif key == 'pid_file':
                    self.pid_file = value
                elif key == 'ipc_socket':
                    self.ipc_socket = value
                elif key == 'ipc_timeout':
                    self.ipc_timeout = float(value)
                elif key == 'request_timeout':
                    self.request_timeout = float(value)
                elif key == 'request_queue_size':
                    self.request_queue_size = int(value)
                elif key == 'request_queue_timeout':
                    self.request_queue_timeout = float(value)
                elif key == 'worker_count':
                    self.worker_count = int(value)
                elif key == 'stderr_log_level':
                    self.stderr_log_level = value.upper()
                elif key == 'file_log_level':
                    self.file_log_level = value.upper()
                elif key == 'log_file':
                    self.log_file = value
                elif key == 'host':
                    self.host = value
                elif key == 'port':
                    self.port = int(value)
                elif key == 'user':
                    self.user = value
                elif key == 'password':
                    self.password = value
                elif key == 'dbname':
                    self.dbname = value
                elif key == 'metric':
                    if read_metrics:
                        self.add_metric(value)
                else:
                    raise InvalidConfigError("WARNING: Unknown config: {}".format(key))

    def add_metric(self, metric_def):
        """
        Parse a 'name:type:version:sql' metric definition and register it.

        The sql part may be 'file:<path>' to load the query text from a file.
        An empty version means 0 (the unconditional fallback version).
        """
        # BUG FIX: the original tested 'sql is None', which a 4-way split can
        # never produce; too few fields crashed with a raw ValueError.
        try:
            (name, ret_type, version, sql) = metric_def.split(':', 3)
        except ValueError:
            raise InvalidConfigError("Invalid metric definition: {}".format(metric_def))

        # The query may live in an external file
        if sql.startswith('file:'):
            (_, path) = sql.split(':', 1)
            with open(path, 'r') as f:
                sql = f.read()

        if version == '':
            version = 0

        try:
            metric = self.metrics[name]
        except KeyError:
            metric = Metric(name, ret_type)
            self.metrics[name] = metric

        metric.add_version(int(version), sql)

    def get_pg_version(self):
        """
        Return the server's version number (server_version_num), caching it
        after the first successful query.
        """
        if self.pg_version is None:
            db = DB(self)
            self.pg_version = int(db.query('SHOW server_version_num')[0]['server_version_num'])
            db.close()
        return self.pg_version


class DB:
    """
    Database access
    """
    def __init__(self, config, dbname=None):
        """Open a read-only, autocommit connection (dbname defaults to config.dbname)"""
        if dbname is None:
            dbname = config.dbname
        self.conn = psycopg2.connect(
            host=config.host,
            port=config.port,
            user=config.user,
            password=config.password,
            dbname=dbname)
        # Metrics never need to write
        self.conn.set_session(readonly=True, autocommit=True)

    def query(self, sql, args=[]):
        """Execute sql with args and return all rows as a list of dicts"""
        # Note: the mutable default is never modified, so it is safe here
        with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(sql, args)
            return cur.fetchall()

    def close(self):
        """Close the connection, downgrading any close-time error to a warning"""
        try:
            if self.conn is not None:
                self.conn.close()
        except psycopg2.Error as e:
            logger.warning("Caught an error when closing a connection: {}".format(e))
        self.conn = None


class Request:
    """
    A metric request
    """
    def __init__(self, metric_name, args={}):
        """Create a pending request for metric_name with template args"""
        self.metric_name = metric_name
        self.args = args
        self.result = None

        # Add a lock to indicate when the request is complete: held while
        # pending, released by set_result()
        self.complete = threading.Lock()
        self.complete.acquire()

    def set_result(self, result):
        """Record the result and mark the request complete"""
        self.result = result

        # Release the lock
        self.complete.release()

    def get_result(self, timeout=-1):
        """
        Block until the request completes and return the result.

        Raises RequestTimeoutError if timeout (seconds) elapses first;
        timeout=-1 waits forever.
        """
        # Wait until the request has been completed
        if self.complete.acquire(timeout=timeout):
            return self.result
        else:
            raise RequestTimeoutError


class Metric:
    """
    A metric
    """
    def __init__(self, name, ret_type):
        self.name = name
        self.versions = {}       # minimum pg version -> MetricVersion
        self.type = ret_type     # one of: value, row, column, set
        self.cached = None           # last MetricVersion returned
        self.cached_version = None   # pg version the cache was computed for

    def add_version(self, version, sql):
        """Register sql for PostgreSQL versions >= version (0 = fallback)"""
        if version in self.versions:
            raise DuplicateMetricVersionError
        self.versions[version] = MetricVersion(sql)

    def get_version(self, version):
        """
        Return the best MetricVersion for the given PostgreSQL version.

        Raises UnsupportedMetricVersionError when no registered version
        is low enough.
        """
        # Since we're usually going to keep asking for the same version,
        # we cache the metric version and only search through the versions
        # again if the PostgreSQL version changes
        if self.cached is None or self.cached_version != version:
            self.cached = None
            self.cached_version = None
            # Highest minimum version that does not exceed the server version
            for v in reversed(sorted(self.versions.keys())):
                if version >= v:
                    self.cached = self.versions[v]
                    self.cached_version = version
                    break
        if self.cached is None:
            raise UnsupportedMetricVersionError
        return self.cached


class MetricVersion:
    """
    A version of a metric
    """
    def __init__(self, sql):
        self.sql = sql

    def get_sql(self, args):
        """Return the SQL with {name} placeholders filled from args"""
        return self.sql.format(**args)


class IPC:
    """
    IPC handler
    """
    def __init__(self, config, mode):
        """mode is 'server' (bind/listen) or 'agent' (connect)"""
        self.config = config
        self.mode = mode
        self.reconnect()

    def reconnect(self):
        """Create the UNIX socket for the configured mode"""
        if self.mode == 'server':
            # Try to clean up the socket if it exists
            try:
                os.unlink(self.config.ipc_socket)
            except OSError:
                if os.path.exists(self.config.ipc_socket):
                    raise

            # Create the socket
            self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

            # Set a timeout for accepting connections to be able to catch signals
            self.socket.settimeout(1)

            self.socket.bind(self.config.ipc_socket)
            self.socket.listen(1)
        elif self.mode == 'agent':
            # Connect to the socket
            self.conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.conn.settimeout(self.config.ipc_timeout)
        else:
            raise RuntimeError

    def connect(self):
        """
        Establish a connection to a socket
        """
        self.conn.connect(self.config.ipc_socket)
        return self.conn

    def accept(self):
        """
        Accept a connection
        """
        (conn, _) = self.socket.accept()
        conn.settimeout(self.config.ipc_timeout)
        return conn

    def send(self, conn, msg):
        """Send msg as UTF-8 with a 4-byte big-endian length prefix"""
        msg_bytes = msg.encode("utf-8")
        msg_len = len(msg_bytes)
        msg_len_bytes = msg_len.to_bytes(4, byteorder='big')
        conn.sendall(msg_len_bytes + msg_bytes)

    def recv(self, conn, timeout=-1):
        """
        Receive one length-prefixed message and return it as a str.

        Note: the timeout parameter is currently unused; the socket-level
        timeouts set in accept()/reconnect() apply instead.
        """
        buffer = bytearray()

        # Read at least the 4-byte length prefix
        while len(buffer) < 4:
            chunk = conn.recv(1024)
            # BUG FIX: recv() returning b'' means the peer closed the
            # connection; the original busy-looped forever in that case.
            if not chunk:
                raise ConnectionError("Connection closed while receiving message")
            buffer += chunk

        msg_len = int.from_bytes(buffer[:4], byteorder='big')

        # Read the rest of the message
        while len(buffer) < msg_len + 4:
            chunk = conn.recv(1024)
            if not chunk:
                raise ConnectionError("Connection closed while receiving message")
            buffer += chunk

        return bytes(buffer[4:]).decode("utf-8")


class Agent:
    """
    The agent side of the connector
    """
    @staticmethod
    def run(config_file, key):
        """Send a single request (key) through the IPC socket and print the reply"""
        config = Config(config_file, read_metrics=False)

        # Connect to the socket
        ipc = IPC(config, 'agent')
        try:
            conn = ipc.connect()

            # Send a request
            ipc.send(conn, key)

            # Wait for a response
            res = ipc.recv(conn)
        except Exception:
            print("IPC error")
            sys.exit(1)

        # Output the response
        print(res)


class Server:
    """
    The server side of the connector
    """
    def __init__(self, config_file):
        self.config_file = config_file

        # Load config
        self.config = Config(config_file)

        # Write pid file
        # Note: we record the PID file here so it can't be changed with reload
        self.pid_file = self.config.pid_file
        write_pid_file(self.pid_file)

        # Initialize logging
        init_logging(self.config)

        # Set up the signal handler
        # BUG FIX: signal_handler handles SIGTERM but it was never
        # registered, so a plain 'kill' bypassed the clean shutdown path.
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGHUP, signal_handler)

        # Create request queue
        self.req_queue = queue.Queue(self.config.request_queue_size)

        # Spawn worker threads
        self.workers = self.spawn_workers(self.config)

    def spawn_workers(self, config):
        """Start config.worker_count Worker threads reading from the request queue"""
        logger.info("Spawning {} workers".format(config.worker_count))
        workers = [None] * config.worker_count
        for i in range(config.worker_count):
            workers[i] = Worker(config, self.req_queue)
            workers[i].start()
        return workers

    def retire_workers(self, workers):
        """Flag the given workers to exit and wait for them to finish"""
        logger.info("Retiring {} workers".format(len(workers)))
        for worker in workers:
            worker.active = False
        for worker in workers:
            worker.join()

    def reload_config(self):
        """Re-read the config file, swap in new workers, and retire the old ones"""
        # Clear the reload flag
        global reload
        reload = False

        # Read new config
        new_config = Config(self.config_file)

        # Re-init logging in case settings changed
        init_logging(new_config)

        # Spawn new workers
        new_workers = self.spawn_workers(new_config)

        # Replace workers
        old_workers = self.workers
        self.workers = new_workers

        # Retire old workers
        self.retire_workers(old_workers)

        # Adjust other settings
        # TODO
        self.config = new_config

    def run(self):
        """Main loop: accept IPC requests, queue them, and spawn Responders"""
        logger.info("Server starting")

        # Listen on ipc socket
        ipc = IPC(self.config, 'server')

        while True:
            # Wait for a request connection
            try:
                conn = ipc.accept()
            except socket.timeout:
                conn = None

            # See if we should exit
            if not running:
                break

            # Check if we're supposed to reload the config
            if reload:
                try:
                    self.reload_config()
                except Exception as e:
                    # BUG FIX: was 'except Exceptioin' / 'logger.ERROR',
                    # both of which blew up on any reload failure.
                    logger.error("Reload failed: {}".format(e))

            # If we just timed out waiting for a request, go back to waiting
            if conn is None:
                continue

            # Get and parse the request string (csv: name[,k=v[,k=v...]])
            key = None
            try:
                # BUG FIX: recv() was outside this try block, so the
                # socket.timeout handler below could never fire.
                key = ipc.recv(conn)
                parts = key.split(',', 1)
                metric_name = parts[0]
                args_dict = {}
                if len(parts) > 1:
                    for arg in parts[1].split(','):
                        if arg != '':
                            (k, v) = arg.split('=', 1)
                            args_dict[k] = v
            except socket.timeout:
                logger.warning("IPC communication timeout receiving request")
                conn.close()
                continue
            except Exception:
                logger.warning("Received invalid request: '{}'".format(key))
                ipc.send(conn, "ERROR: Invalid key")
                conn.close()
                continue

            # Create request object
            req = Request(metric_name, args_dict)

            # Queue the request
            try:
                self.req_queue.put(req, timeout=self.config.request_queue_timeout)
            except queue.Full:
                # BUG FIX: the original 'continue'd here, so the error result
                # was never sent and the connection leaked; fall through so
                # the Responder delivers the already-set error.
                logger.warning("Failed to queue request")
                req.set_result("ERROR: Queue timeout")

            # Spawn a thread to wait for the result
            r = Responder(self.config, ipc, conn, req)
            r.start()

        # Join worker threads
        self.retire_workers(self.workers)

        # Clean up the PID file
        remove_pid_file(self.pid_file)

        logger.info("Good bye")


class Worker(threading.Thread):
    """
    Worker thread that processes requests (ie: queries the database)
    """
    def __init__(self, config, req_queue):
        super(Worker, self).__init__()
        self.config = config
        self.queue = req_queue
        self.active = True   # cleared by Server.retire_workers to stop the loop

    def run(self):
        while True:
            # Wait for a request
            try:
                req = self.queue.get(timeout=1)
            except queue.Empty:
                req = None

            # Check if we're supposed to exit
            if not self.active:
                logger.info("Worker exiting: tid={}".format(self.native_id))
                break

            # If we got here because we waited too long for a request, go back to waiting
            if req is None:
                continue

            # Find the requested metric
            try:
                metric = self.config.metrics[req.metric_name]
            except KeyError:
                req.set_result("ERROR: Unknown key '{}'".format(req.metric_name))
                continue

            # Get the DB version
            try:
                pg_version = self.config.get_pg_version()
            except Exception as e:
                req.set_result("Failed to retrieve database version")
                logger.error("Failed to get Postgresql version: {}".format(e))
                continue

            # Get the query to use
            try:
                mv = metric.get_version(pg_version)
            except UnsupportedMetricVersionError:
                req.set_result("Unsupported PostgreSQL version for metric")
                continue

            # Query the database, optionally against a caller-supplied db
            dbname = req.args.get('db')

            try:
                db = DB(self.config, dbname)
                try:
                    res = db.query(mv.get_sql(req.args))
                finally:
                    # Make sure the database connection is closed
                    # (ie: if the query timed out)
                    db.close()
            except Exception as e:
                req.set_result("Failed to query database")
                logger.error("Database error: {}".format(e))
                continue

            # Set the result on the request, formatted per the metric type
            if len(res) == 0:
                req.set_result("Empty result set")
            elif metric.type == 'value':
                req.set_result("{}".format(list(res[0].values())[0]))
            elif metric.type == 'row':
                req.set_result(json.dumps(res[0]))
            elif metric.type == 'column':
                req.set_result(json.dumps([list(r.values())[0] for r in res]))
            elif metric.type == 'set':
                req.set_result(json.dumps(res))
            else:
                # BUG FIX: an unrecognized type left the request unresolved,
                # blocking its Responder thread forever.
                req.set_result("ERROR: Unknown metric type '{}'".format(metric.type))


class Responder(threading.Thread):
    """
    Thread responsible for replying to requests
    """
    def __init__(self, config, ipc, conn, req):
        super(Responder, self).__init__()
        # BUG FIX: the original did the blocking wait and the send right
        # here in __init__, so the "thread" actually ran on the caller
        # (blocking the server's accept loop) and start() did nothing.
        self.config = config
        self.ipc = ipc
        self.conn = conn
        self.req = req

    def run(self):
        # Wait for a result
        result = self.req.get_result()

        # Send the result back to the client, then release the connection
        try:
            self.ipc.send(self.conn, result)
        finally:
            self.conn.close()


def main():
    """Parse command line arguments and dispatch to server/reload/agent mode"""
    parser = argparse.ArgumentParser(
        prog='pgmon',
        description='A bridge between monitoring tools and PostgreSQL')

    # General options
    parser.add_argument('-c', '--config', default='pgmon.cfg')
    parser.add_argument('-v', '--verbose', action='store_true')

    # Operational mode
    parser.add_argument('-s', '--server', action='store_true')
    parser.add_argument('-r', '--reload', action='store_true')

    # Agent options
    parser.add_argument('key', nargs='?')

    args = parser.parse_args()

    if args.server:
        server = Server(args.config)
        server.run()
    elif args.reload:
        # Signal the running server (located via its PID file) to reload
        config = Config(args.config, read_metrics=False)
        pid = read_pid_file(config.pid_file)
        os.kill(pid, signal.SIGHUP)
    else:
        Agent.run(args.config, args.key)

if __name__ == '__main__':
    main()


##
# Tests
##

class TestConfig:
    pass

class TestDB:
    pass

class TestRequest:
    def test_request_creation(self):
        # Create result with no args
        req1 = Request('foo', {})
        assert req1.metric_name == 'foo'
        assert len(req1.args) == 0
        assert req1.complete.locked()

        # Create result with args
        req2 = Request('foo', {'arg1': 'value1', 'arg2': 'value2'})
        assert req2.metric_name == 'foo'
        assert len(req2.args) == 2
        assert req2.complete.locked()

    def test_request_lock(self):
        req1 = Request('foo', {})
        assert req1.complete.locked()
        req1.set_result('blah')
        assert not req1.complete.locked()
        assert 'blah' == req1.get_result()

class TestMetric:
    def test_metric_creation(self):
        # Test basic creation
        m1 = Metric('foo', 'value')
        assert m1.name == 'foo'
        assert len(m1.versions) == 0
        assert m1.type == 'value'
        assert m1.cached is None
        assert m1.cached_version is None

    def test_metric_add_version(self):
        # Test adding versions
        m1 = Metric('foo', 'value')
        assert len(m1.versions) == 0
        m1.add_version(0, 'default')
        assert len(m1.versions) == 1
        m1.add_version(120003, 'v12.3')
        assert len(m1.versions) == 2

        # Make sure added versions are correct
        assert m1.versions[0].sql == 'default'
        assert m1.versions[120003].sql == 'v12.3'

    def test_metric_get_version(self):
        # Test retrieving metric versions
        m1 = Metric('foo', 'value')
        m1.add_version(100000, 'v10.0')
        m1.add_version(120000, 'v12.0')

        # Make sure cache is initially empty
        assert m1.cached is None
        assert m1.cached_version is None

        assert m1.get_version(110003).sql == 'v10.0'

        # Make sure cache is set
        assert m1.cached is not None
        assert m1.cached_version == 110003

        # Make sure returned value changes with version
        assert m1.get_version(120000).sql == 'v12.0'
        assert m1.get_version(150005).sql == 'v12.0'

        # Make sure an error is thrown when no version matches
        with pytest.raises(UnsupportedMetricVersionError):
            m1.get_version(90603)

        # Add a default version
        m1.add_version(0, 'default')
        assert m1.get_version(90603).sql == 'default'
        assert m1.get_version(110003).sql == 'v10.0'
        assert m1.get_version(120000).sql == 'v12.0'
        assert m1.get_version(150005).sql == 'v12.0'

class TestMetricVersion:
    def test_metric_version_creation(self):
        mv1 = MetricVersion('test')
        assert mv1.sql == 'test'

    def test_metric_version_templating(self):
        mv1 = MetricVersion('foo')
        assert mv1.get_sql({}) == 'foo'

        mv2 = MetricVersion('foo {a1} {a3} {a2}')
        assert mv2.get_sql({'a1': 'bar', 'a2': 'blah blah blah', 'a3': 'baz'}) == 'foo bar baz blah blah blah'

class TestIPC:
    pass

class TestAgent:
    pass

class TestServer:
    pass

class TestWorker:
    pass

class TestResponder:
    pass