From 13b82d66c661fc2cb2a27d8c50cc6924bcb48ffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Se=C3=A1n=20Healy?= Date: Fri, 1 May 2026 21:29:04 +0100 Subject: [PATCH] Do some connection pooling. This prevents exhaustion of connections and also has some performance benefits. --- Cargo.lock | 21 ++++++++++++++ Cargo.toml | 2 +- pyproject.toml | 2 +- src/lib.rs | 76 +++++++++++++++++++++++++++++++++++++++----------- 4 files changed, 83 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 36f2cbc..6da4d48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1185,6 +1185,7 @@ dependencies = [ "downcast-rs", "itoa", "pq-sys", + "r2d2", ] [[package]] @@ -2652,6 +2653,17 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r2d2" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" +dependencies = [ + "log", + "parking_lot", + "scheduled-thread-pool", +] + [[package]] name = "rand" version = "0.9.4" @@ -2994,6 +3006,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot", +] + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 7021875..12c9c76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ aws-config = "1" aws-sdk-s3 = "1" futures-util = "0.3" lapin = "2" -diesel = { version = "2", features = ["postgres"] } +diesel = { version = "2", features = ["postgres", "r2d2"] } reqwest = { version = "0.12", features = ["blocking", "json", "rustls-tls"], default-features = false } serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/pyproject.toml b/pyproject.toml index 6830bf9..55781d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "slingshot-microservice" -version = "0.1.4" +version = "0.1.5" description = "Opinionated Rust framework for queue-driven microservices" license = { text = "MIT" } requires-python = ">=3.8" diff --git a/src/lib.rs b/src/lib.rs index 3d97764..8a334e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,13 +8,14 @@ use std::path::{Path, PathBuf}; use std::process::Command; use std::sync::Arc; use std::sync::Mutex; +use std::sync::OnceLock; use std::sync::atomic::{AtomicU64, Ordering}; use aws_config::BehaviorVersion; use aws_sdk_s3::Client; use aws_sdk_s3::config::{Credentials, Region}; use aws_sdk_s3::primitives::ByteStream; -use diesel::Connection as DieselConnection; +use diesel::r2d2::{ConnectionManager, Pool, PooledConnection}; use diesel::PgConnection; use futures_util::StreamExt; use lapin::options::{ @@ -49,6 +50,8 @@ type ProcessFn = dyn Fn(u64, Arc, Arc) -> Result = OnceLock::new(); +static DATABASE_POOL_CACHE: OnceLock>> = OnceLock::new(); fn init_tracing() { let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); @@ -128,7 +131,7 @@ impl Microservice { | -> Result, AnyError> { let mut connection = establish_pg_connection().map_err(|e| { format!( - "failed to establish PostgreSQL connection for request {}: {}", + "failed to checkout PostgreSQL connection from pool for request {}: {}", request, e ) })?; @@ -601,10 +604,38 @@ fn resolve_password_from_pass(pass_key: &str) -> Result { Ok(password) } -fn establish_pg_connection() -> Result { - let database_url = fetch_database_url_from_sys_map()?; - PgConnection::establish(&database_url) - .map_err(|e| format!("failed to connect to PostgreSQL using sys-map DB config: {}", e).into()) +fn establish_pg_connection() -> Result>, AnyError> { + let pool = cached_pg_pool()?; + pool + .get() + .map_err(|e| format!("failed to get PostgreSQL connection from pool: {}", e).into()) +} + +fn cached_database_url() -> Result { + if let Some(url) = DATABASE_URL_CACHE.get() { + return Ok(url.clone()); + } + + let url = fetch_database_url_from_sys_map()?; + let _ = DATABASE_URL_CACHE.set(url.clone()); + Ok(url) +} + +fn cached_pg_pool() -> Result<&'static Pool>, AnyError> { + if let Some(pool) = DATABASE_POOL_CACHE.get() { + return Ok(pool); + } + + let database_url = cached_database_url()?; + let manager = ConnectionManager::::new(database_url); + let pool = Pool::builder() + .build(manager) + .map_err(|e| format!("failed to build PostgreSQL connection pool: {}", e))?; + + let _ = DATABASE_POOL_CACHE.set(pool); + DATABASE_POOL_CACHE + .get() + .ok_or_else(|| "failed to initialize PostgreSQL connection pool".into()) } fn single_value(values: &[T], field_name: &str) -> Result { @@ -854,25 +885,18 @@ impl PyWriteFileFn { #[cfg(feature = "python")] fn run_python_process( process: &Py, + engine: &Py, request: u64, read_file: Arc, write_file: Arc, ) -> Result, AnyError> { - let database_url = fetch_database_url_from_sys_map()?; - Python::with_gil(|py| -> Result, AnyError> { let py_read = Py::new(py, PyReadFileFn { inner: read_file }) .map_err(|e| format!("failed to build Python ReadFileFn wrapper: {}", e))?; let py_write = Py::new(py, PyWriteFileFn { inner: write_file }) .map_err(|e| format!("failed to build Python WriteFileFn wrapper: {}", e))?; - let sqlalchemy = py - .import("sqlalchemy") - .map_err(|e| format!("failed to import sqlalchemy: {}", e))?; - let engine = sqlalchemy - .getattr("create_engine") - .and_then(|f| f.call1((database_url.as_str(),))) - .map_err(|e| format!("failed to create SQLAlchemy engine: {}", e))?; let connection = engine + .bind(py) .call_method0("connect") .map_err(|e| format!("failed to open SQLAlchemy connection: {}", e))?; @@ -909,6 +933,23 @@ fn run_python_process( }) } +#[cfg(feature = "python")] +fn create_python_sqlalchemy_engine() -> Result, AnyError> { + let database_url = cached_database_url()?; + + Python::with_gil(|py| { + let sqlalchemy = py + .import("sqlalchemy") + .map_err(|e| format!("failed to import sqlalchemy: {}", e))?; + let engine = sqlalchemy + .getattr("create_engine") + .and_then(|f| f.call1((database_url.as_str(),))) + .map_err(|e| format!("failed to create SQLAlchemy engine: {}", e))?; + + Ok(engine.unbind()) + }) +} + #[cfg(feature = "python")] #[pyclass(name = "Microservice")] struct PyMicroservice { @@ -937,10 +978,13 @@ impl PyMicroservice { fn start(&self) -> PyResult<()> { let process = Python::with_gil(|py| self.process.clone_ref(py)); + let engine = create_python_sqlalchemy_engine().map_err(any_error_to_py)?; let microservice = Microservice::new_case_key( self.name.clone(), self.config_host.clone(), - Arc::new(move |request, read_file, write_file| run_python_process(&process, request, read_file, write_file)), + Arc::new(move |request, read_file, write_file| { + run_python_process(&process, &engine, request, read_file, write_file) + }), ); microservice.start().map_err(any_error_to_py)