From 3abbc1d83b30a8c3a747cfa6305c7f9550788d31 Mon Sep 17 00:00:00 2001 From: James Campbell Date: Thu, 16 May 2024 11:41:47 -0400 Subject: [PATCH] Initial implementation --- pgmon.cfg | 5 + pgmon.py | 676 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 681 insertions(+) create mode 100644 pgmon.cfg create mode 100755 pgmon.py diff --git a/pgmon.cfg b/pgmon.cfg new file mode 100644 index 0000000..67da094 --- /dev/null +++ b/pgmon.cfg @@ -0,0 +1,5 @@ +host=localhost +port=54315 +metric=max_frozen_age:value::SELECT max(age(datfrozenxid)) FROM pg_database +metric=db_stats:row::SELECT * FROM pg_stat_database WHERE datname = '{datname}' +metric=discover_dbs:column::SELECT datname FROM pg_database diff --git a/pgmon.py b/pgmon.py new file mode 100755 index 0000000..e777b3e --- /dev/null +++ b/pgmon.py @@ -0,0 +1,676 @@ +#!/usr/bin/env python3 + +import argparse +import logging +import socket +import sys +import threading + +if "pytest" in sys.modules: + # Conditional modules are needed for tests, so import them if this is a test + import pytest + + @pytest.fixture(scope="session", autouse=True) + def pytest_imports(): + import psycopg2 + import psycopg2.extras + import queue + import os + import signal + import json + +# +# Errors +# + +class InvalidConfigError(Exception): + pass +class DuplicateMetricVersionError(Exception): + pass +class UnsupportedMetricVersionError(Exception): + pass +class RequestTimeoutError(Exception): + pass + +# +# Global variables +# +running = True + +# +# Signal handler +# +def signal_handler(sig, frame): + # Restore the original handler + signal.signal(signal.SIGINT, signal.default_int_handler) + + # Signal everything to shut down + if sig == signal.SIGINT: + print("Shutting down ...") + global running + running = False + +# +# Classes +# + +class Config: + """ + Agent configuration + """ + def __init__(self, config_file, read_metrics = True): + # Set defaults + self.ipc_socket = '/tmp/pgmon.sock' # IPC socket + self.ipc_timeout = 10 # IPC communication timeout (s) + self.request_timeout = 10 # Request processing timeout (s) + self.request_queue_size = 100 # Max size of the request queue before it blocks + self.request_queue_timeout = 2 # Max time to wait when queueing a request (s) + self.worker_count = 4 # Number of worker threads + + self.host = 'localhost' # DB host + self.port = 5432 # DB port + self.user = 'postgres' # DB user + self.password = None # DB password + self.dbname = 'postgres' # DB name + + self.metrics = {} # Metrics + + # Dynamic values + self.pg_version = None # PostgreSQL version + + # Read config + self.read_file(config_file, read_metrics) + + def read_file(self, config_file, read_metrics = True): + with open(config_file, 'r') as f: + for line in f: + line = line.strip() + + if line.startswith('#'): + continue + elif line == '': + continue + + (key, value) = line.split('=', 1) + if value is None: + raise InvalidConfigError("{}: {}", config_file, line) + + if key == 'include': + self.read_file(value, read_metric) + elif key == 'socket': + self.ipc_socket = value + elif key == 'socket_timeout': + self.ipc_timeout = float(value) + elif key == 'request_timeout': + self.request_timeout = float(value) + elif key == 'request_queue_size': + self.request_queue_size = int(value) + elif key == 'request_queue_timeout': + self.request_queue_timeout = float(value) + elif key == 'worker_count': + self.worker_count = int(value) + elif key == 'host': + self.host = value + elif key == 'port': + self.port = int(value) + elif key == 'user': + self.user = value + elif key == 'password': + self.password = value + elif key == 'dbname': + self.dbname = value + elif key == 'metric': + if read_metrics: + self.add_metric(value) + + def add_metric(self, metric_def): + (name, ret_type, version, sql) = metric_def.split(':', 3) + if sql is None: + raise InvalidConfigError + + if sql.startswith('file:'): + (_,path) = sql.split(':', 1) + with open(path, 'r') as f: + sql = f.read() + + if version == '': + version = 0 + + try: + metric = self.metrics[name] + except KeyError: + metric = Metric(name, ret_type) + self.metrics[name] = metric + + metric.add_version(int(version), sql) + + def get_pg_version(self): + if self.pg_version is None: + db = DB(self) + self.pg_version = int(db.query('SHOW server_version_num')[0]['server_version_num']) + db.close() + return self.pg_version + + +class DB: + """ + Database access + """ + def __init__(self, config, dbname=None): + if dbname is None: + dbname = config.dbname + self.conn = psycopg2.connect( + host = config.host, + port = config.port, + user = config.user, + password = config.password, + dbname = dbname); + self.conn.set_session(readonly=True, autocommit=True) + + def query(self, sql, args=[]): + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute(sql, args) + return(cur.fetchall()) + + def close(self): + try: + if self.conn is not None: + self.conn.close() + except psycopg2.Error as e: + print("Caught an error when closing a connection: {}".format(e)) + self.conn = None + + +class Request: + """ + A metric request + """ + def __init__(self, key, args = {}): + self.key = key + self.args = args + self.result = None + + # Add a lock to indicate when the request is complete + self.complete = threading.Lock() + self.complete.acquire() + + def set_result(self, result): + self.result = result + + # Release the lock + self.complete.release() + + def get_result(self, timeout = -1): + # Wait until the request has been completed + if self.complete.acquire(timeout = timeout): + return self.result + else: + raise RequestTimeoutError + + +class Metric: + """ + A metric + """ + def __init__(self, name, ret_type): + self.name = name + self.versions = {} + self.type = ret_type + self.cached = None + self.cached_version = None + + def add_version(self, version, sql): + if version in self.versions: + raise DuplicateMetricVersionError + self.versions[version] = MetricVersion(sql) + + def get_version(self, version): + # Since we're usually going to keep asking for the same version, + # we cache the metric version and only search through the versions + # again if the PostgreSQL version changes + if self.cached is None or self.cached_version != version: + self.cached = None + self.cached_version = None + for v in reversed(sorted(self.versions.keys())): + if version >= v: + self.cached = self.versions[v] + self.cached_version = version + break + if self.cached is None: + raise UnsupportedMetricVersionError + + return self.cached + + +class MetricVersion: + """ + A version of a metric + """ + def __init__(self, sql): + self.sql = sql + + def get_sql(self, args): + return self.sql.format(**args) + + +class IPC: + """ + IPC handler + """ + def __init__(self, config, mode): + self.config = config + self.mode = mode + self.reconnect() + + def reconnect(self): + if self.mode == 'server': + # Try to clean up the socket if it exists + try: + os.unlink(self.config.ipc_socket) + except OSError: + if os.path.exists(self.config.ipc_socket): + raise + + # Create the socket + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + # Set a timeout for accepting connections to be able to catch signals + self.socket.settimeout(1) + self.socket.bind(self.config.ipc_socket) + self.socket.listen(1) + + elif self.mode == 'agent': + # Connect to the socket + self.conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.conn.settimeout(self.config.ipc_timeout) + + else: + raise RuntimeError + + def connect(self): + """ + Establish a connection to a socket + """ + self.conn.connect(self.config.ipc_socket) + return self.conn + + def accept(self): + """ + Accept a connection + """ + (conn, _) = self.socket.accept() + conn.settimeout(self.config.ipc_timeout) + return conn + + def send(self, conn, msg): + msg_bytes = msg.encode("utf-8") + msg_len = len(msg_bytes) + msg_len_bytes = msg_len.to_bytes(4, byteorder='big') + conn.sendall(msg_len_bytes + msg_bytes) + + def recv(self, conn, timeout=-1): + # Read at least the length + buffer = [] + msg_len = -4 + while msg_len < 0: + buffer += conn.recv(1024) + msg_len = len(buffer) - 4 + + msg_len_bytes = buffer[:4] + msg_len = int.from_bytes(msg_len_bytes, byteorder='big') + + while len(buffer) < msg_len + 4: + buffer += conn.recv(1024) + + return bytes(buffer[4:]).decode("utf-8") + + +class Agent: + """ + The agent side of the connector + """ + @staticmethod + def run(config, key, args): + # Connect to the socket + ipc = IPC(config, 'agent') + try: + conn = ipc.connect() + + # Send a request + ipc.send(conn, "{},{}".format(key, ",".join(args))) + + # Wait for a response + res = ipc.recv(conn) + + except: + print("IPC error") + sys.exit(1) + + # Output the response + print(res) + + +class Server: + """ + The server side of the connector + """ + @staticmethod + def run(config): + import psycopg2 + import psycopg2.extras + import queue + import os + import signal + import json + + # Set up the signal handler + signal.signal(signal.SIGINT, signal_handler) + + # Create reqest queue + req_queue = queue.Queue(config.request_queue_size) + + # Spawn worker threads + workers = [None] * config.worker_count + for i in range(config.worker_count): + workers[i] = Worker(config, req_queue) + workers[i].start() + + # Listen on ipc socket + ipc = IPC(config, 'server') + + while True: + # Wait for a request connection + try: + conn = ipc.accept() + except socket.timeout: + conn = None + + # See if we should exit + if not running: + break + + # If we just timed out waiting for a request, go back to waiting + if conn is None: + continue + + # Get the request string (csv) + req_csv = ipc.recv(conn) + + # Receive ipc request (csv) + try: + (key, args_csv) = req_csv.split(',', 1) + args_dict = {} + if args_csv != "": + for (k, v) in [a.split('=', 1) for a in args_csv.split(',')]: + args_dict[k] = v + except socket.timeout: + print("IPC communication timeout receiving request") + conn.close() + continue + except Exception: + print("Received invalid request: '{}'".format(req_csv)) + conn.close() + continue + + # Create request object + req = Request(key, args_dict) + + # Queue the request + try: + req_queue.put(req, timeout=config.request_queue_timeout) + except: + print("Failed to queue request") + req.set_result("ERROR: Queue timeout") + continue + + # Spawn a thread to wait for the result + r = Responder(config, ipc, conn, req) + r.start() + + # Join worker threads + for worker in workers: + worker.join() + + print("Good bye") + + +class Worker(threading.Thread): + """ + Worker thread that processes requests (ie: queries the database) + """ + def __init__(self, config, queue): + super(Worker, self).__init__() + + self.config = config + self.queue = queue + + def run(self): + while True: + # Wait for a request + try: + req = self.queue.get(timeout=1) + except queue.Empty: + req = None + + # Check if we're supposed to exit + if not running: + break + + # If we got here because we waited too long for a request, go back to waiting + if req is None: + continue + + # Find the requested metrtic + try: + metric = self.config.metrics[req.key] + except KeyError: + req.set_result("ERROR: Unknown key '{}'".format(req.key)) + continue + + # Get the DB version + try: + pg_version = self.config.get_pg_version() + except Exception as e: + req.set_result("Failed to retrieve database version") + print("Error: {}".format(e)) + continue + + # Get the query to use + try: + mv = metric.get_version(pg_version) + except UnsupportedMetricVersionError: + req.set_result("Unsupported PosgreSQL version for metric") + continue + + # Query the database + try: + dbname = req.args['db'] + except KeyError: + dbname = None + + try: + db = DB(self.config, dbname) + res = db.query(mv.get_sql(req.args)) + db.close() + except Exception as e: + # Make sure the database connectin is closed (ie: if the query timed out) + try: + db.close() + except: + pass + req.set_result("Failed to query database") + print("Error: {}".format(e)) + continue + + # Set the result on the request + if len(res) == 0: + req.set_result("Empty result set") + elif metric.type == 'value': + req.set_result("{}".format(list(res[0].values())[0])) + elif metric.type == 'row': + req.set_result(json.dumps(res[0])) + elif metric.type == 'column': + req.set_result(json.dumps([list(r.values())[0] for r in res])) + else: + req.set_result(json.dumps(res)) + + try: + db.close() + except Exception as e: + print("Failed to close database connection") + print("Error: {}".format(e)) + +class Responder(threading.Thread): + """ + Thread responsible for replying to requests + """ + def __init__(self, config, ipc, conn, req): + super(Responder, self).__init__() + + # Wait for a result + result = req.get_result() + + # Send the result back to the client + ipc.send(conn, result) + +def main(): + parser = argparse.ArgumentParser( + prog='pgmon', + description='A briidge between monitoring tools and PostgreSQL') + + # General options + parser.add_argument('-c', '--config', default='pgmon.cfg') + parser.add_argument('-v', '--verbose', action='store_true') + + # Operational mode + parser.add_argument('-m', '--mode', choices=['agent', 'server'], required=True) + + # Agent options + parser.add_argument('-k', '--key') + parser.add_argument('-a', '--args', nargs='*', default=[]) + + args = parser.parse_args() + + if args.mode == 'agent': + config = Config(args.config, read_metrics = False) + Agent.run(config, args.key, args.args) + + else: + config = Config(args.config) + Server.run(config) + +if __name__ == '__main__': + main() + +## +# Tests +## +class TestConfig: + pass + +class TestDB: + pass + +class TestRequest: + def test_request_creation(self): + # Create result with no args + req1 = Request('foo', {}) + assert req1.key == 'foo' + assert len(req1.args) == 0 + assert req1.complete.locked() + + # Create result with args + req2 = Request('foo', {'arg1': 'value1', 'arg2': 'value2'}) + assert req2.key == 'foo' + assert len(req2.args) == 2 + assert req2.complete.locked() + + def test_request_lock(self): + req1 = Request('foo', {}) + assert req1.complete.locked() + req1.set_result('blah') + assert not req1.complete.locked() + assert 'blah' == req1.get_result() + + + +class TestMetric: + def test_metric_creation(self): + # Test basic creation + m1 = Metric('foo', 'value') + assert m1.name == 'foo' + assert len(m1.versions) == 0 + assert m1.type == 'value' + assert m1.cached is None + assert m1.cached_version is None + + def test_metric_add_version(self): + # Test adding versions + m1 = Metric('foo', 'value') + assert len(m1.versions) == 0 + m1.add_version(0, 'default') + assert len(m1.versions) == 1 + m1.add_version(120003, 'v12.3') + assert len(m1.versions) == 2 + + # Make sure added versions are correct + assert m1.versions[0].sql == 'default' + assert m1.versions[120003].sql == 'v12.3' + + def test_metric_get_version(self): + # Test retrieving metric versions + m1 = Metric('foo', 'value') + m1.add_version(100000, 'v10.0') + m1.add_version(120000, 'v12.0') + + # Make sure cache is initially empty + assert m1.cached is None + assert m1.cached_version is None + + assert m1.get_version(110003).sql == 'v10.0' + + # Make sure cache is set + assert m1.cached is not None + assert m1.cached_version is 110003 + + # Make sure returned value changes with version + assert m1.get_version(120000).sql == 'v12.0' + assert m1.get_version(150005).sql == 'v12.0' + + # Make sure an error is thrown when no version matches + with pytest.raises(UnsupportedMetricVersionError): + m1.get_version(90603) + + # Add a default version + m1.add_version(0, 'default') + assert m1.get_version(90603).sql == 'default' + assert m1.get_version(110003).sql == 'v10.0' + assert m1.get_version(120000).sql == 'v12.0' + assert m1.get_version(150005).sql == 'v12.0' + +class TestMetricVersion: + def test_metric_version_creation(self): + mv1 = MetricVersion('test') + assert mv1.sql == 'test' + + def test_metric_version_templating(self): + mv1 = MetricVersion('foo') + assert mv1.get_sql({}) == 'foo' + + mv2 = MetricVersion('foo {a1} {a3} {a2}') + assert mv2.get_sql({'a1': 'bar', 'a2': 'blah blah blah', 'a3': 'baz'}) == 'foo bar baz blah blah blah' + + +class TestIPC: + pass + +class TestAgent: + pass + +class TestServer: + pass + +class TestWorker: + pass + +class TestResponder: + pass