Implement a startup script which runs once on microservice start.

This commit is contained in:
Seán Healy
2026-05-08 20:45:49 +01:00
parent 13b82d66c6
commit c823a73bc4
6 changed files with 246 additions and 21 deletions

View File

@@ -9,6 +9,9 @@ assumptions about a microservice:
1. A microservice listens to incoming requests on its own dedicated and 1. A microservice listens to incoming requests on its own dedicated and
singular queue (RabbitMQ). singular queue (RabbitMQ).
2. Incoming requests are in the form of a 64-bit unsigned integer (`u64`). 2. Incoming requests are in the form of a 64-bit unsigned integer (`u64`).
2. Microservices can run optional startup work via a `startup` function, which
takes three arguments: a `read_file` function, a `write_file` function, and
a database ORM `connection`.
2. Microservices process requests via a `process` function, which takes four 2. Microservices process requests via a `process` function, which takes four
arguments: the incoming request (`u64`), a `read_file` function, a arguments: the incoming request (`u64`), a `read_file` function, a
`write_file` function, and a database ORM `connection`. `write_file` function, and a database ORM `connection`.
@@ -85,6 +88,15 @@ from slingshot_microservice.typing import ReadFileFn, WriteFileFn
from slingshot_microservice import Microservice from slingshot_microservice import Microservice
def startup(
    read_file: ReadFileFn,
    write_file: WriteFileFn,
    connection: Connection,
) -> Generator[tuple[int, bool | int | str], None, None]:
    """Optional one-time startup hook; this example yields no outputs."""
    # An empty ``yield from`` keeps this a generator without emitting anything.
    yield from ()
def process( def process(
request: int, request: int,
read_file: ReadFileFn, read_file: ReadFileFn,
@@ -100,7 +112,12 @@ def process(
yield (request, True) yield (request, True)
microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", process) microservice = Microservice(
"simple-py-microservice",
"sys-map.slingshot.cv",
startup,
process,
)
microservice.start() microservice.start()
``` ```
@@ -113,7 +130,8 @@ editors and type-checkers:
|---|---| |---|---|
| `ReadFileFn` | Callable returned by `read_file(key, id)` behaves like `BinaryIO` | | `ReadFileFn` | Callable returned by `read_file(key, id)` behaves like `BinaryIO` |
| `WriteFileFn` | Callable returned by `write_file(key, id)` behaves like `BinaryIO` | | `WriteFileFn` | Callable returned by `write_file(key, id)` behaves like `BinaryIO` |
| `ProcessFn` | The generator signature expected by `Microservice` with `(request, read_file, write_file, connection)` | | `StartupFn` | The generator signature expected by `Microservice` startup hook with `(read_file, write_file, connection)` |
| `ProcessFn` | The generator signature expected by `Microservice` process callback with `(request, read_file, write_file, connection)` |
| `CaseVariable` | `bool \| int \| str` valid case variable types | | `CaseVariable` | `bool \| int \| str` valid case variable types |
### Publishing Wheels ### Publishing Wheels
@@ -154,11 +172,20 @@ fn process(
Ok(vec![(request, "case_a".to_string())]) Ok(vec![(request, "case_a".to_string())])
} }
/// Startup hook for this example: it performs no work and emits no outputs.
fn startup(
    read_file: &ReadFileFn,
    write_file: &WriteFileFn,
    connection: &mut PgConnection,
) -> Result<Vec<(u64, String)>, AnyError> {
    Ok(vec![])
}
fn main() { fn main() {
// Create a new microservice instance with the processing function // Create a new microservice instance with startup and process functions
let microservice = Microservice::new( let microservice = Microservice::new(
"simple-microservice", "simple-microservice",
"sys-map.example.com", "sys-map.example.com",
startup,
process process
); );
@@ -220,7 +247,11 @@ actual secrets with `pass show <key>` before constructing the S3 client.
When the microservice first starts up, it makes a request to the configuration When the microservice first starts up, it makes a request to the configuration
service to get the queue metadata. Then it starts to listen to the inbound service to get the queue metadata. Then it starts to listen to the inbound
queue. Inbound requests are processed by the user-programmed `process` queue. Before consuming inbound requests, the user-programmed `startup`
function runs once with `(read_file, write_file, connection)` and may return
zero or more `(result_id, case_variable)` tuples for normal outbound routing.
Inbound requests are then processed by the user-programmed `process`
function, which is called with `(request, read_file, write_file, connection)` function, which is called with `(request, read_file, write_file, connection)`
and returns a set of tuples of the form `(result_id, case_variable)`. and returns a set of tuples of the form `(result_id, case_variable)`.
@@ -247,6 +278,10 @@ The output queue routing step looks like this:
Pseudocode: Pseudocode:
``` ```
for each (result_id, case_variable) in startup(read_file, write_file, connection):
for each outbound_queue in config.out[case_variable]:
send result_id to outbound_queue
for each (result_id, case_variable) in process(request, read_file, write_file, connection): for each (result_id, case_variable) in process(request, read_file, write_file, connection):
for each outbound_queue in config.out[case_variable]: for each outbound_queue in config.out[case_variable]:
send result_id to outbound_queue send result_id to outbound_queue

View File

@@ -3,6 +3,15 @@ from slingshot_microservice import Microservice
from typing import Generator from typing import Generator
from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Connection
def startup(
    read_file: ReadFileFn,
    write_file: WriteFileFn,
    connection: Connection,
) -> Generator[tuple[int, bool | int | str], None, None]:
    """One-time startup hook; this example has no startup work and yields nothing.

    The parameter names follow the ``StartupFn`` callback protocol
    (``read_file``, ``write_file``, ``connection``) so that the function
    type-checks when it is handed to ``Microservice``.
    """
    # ``yield from ()`` makes this a generator function that produces no items.
    yield from ()
def process( def process(
request: int, request: int,
read_file: ReadFileFn, read_file: ReadFileFn,
@@ -16,5 +25,5 @@ def process(
yield (request, True) yield (request, True)
microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", process) microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", startup, process)
microservice.start() microservice.start()

View File

@@ -18,8 +18,21 @@ fn process(
Ok(vec![(request, "case_a".to_string())]) Ok(vec![(request, "case_a".to_string())])
} }
/// Startup hook for this example: nothing to prepare, so no outputs are produced.
fn startup(
    _read_file: &ReadFileFn,
    _write_file: &WriteFileFn,
    _connection: &mut PgConnection,
) -> Result<Vec<(u64, String)>, AnyError> {
    Ok(vec![])
}
fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let microservice = Microservice::new("simple-microservice", "sys-map.slingshot.cv", process); let microservice = Microservice::new(
"simple-microservice",
"sys-map.slingshot.cv",
startup,
process,
);
microservice.start()?; microservice.start()?;
Ok(()) Ok(())

View File

@@ -4,7 +4,7 @@ build-backend = "maturin"
[project] [project]
name = "slingshot-microservice" name = "slingshot-microservice"
version = "0.1.5" version = "0.1.6"
description = "Opinionated Rust framework for queue-driven microservices" description = "Opinionated Rust framework for queue-driven microservices"
license = { text = "MIT" } license = { text = "MIT" }
requires-python = ">=3.8" requires-python = ">=3.8"

View File

@@ -33,10 +33,20 @@ class ProcessFn(Protocol):
) -> Generator[tuple[int, CaseVariable], None, None]: ... ) -> Generator[tuple[int, CaseVariable], None, None]: ...
class StartupFn(Protocol):
    """Callable signature of a startup hook.

    A startup hook takes ``read_file``, ``write_file`` and ``connection`` and
    returns a generator of ``(int, CaseVariable)`` tuples.
    """

    def __call__(
        self,
        read_file: ReadFileFn,
        write_file: WriteFileFn,
        connection: SqlAlchemyConnection,
    ) -> Generator[tuple[int, CaseVariable], None, None]:
        """Run the startup hook and yield its ``(int, CaseVariable)`` outputs."""
        ...
__all__ = [ __all__ = [
"CaseVariable", "CaseVariable",
"ReadFileFn", "ReadFileFn",
"WriteHandle", "WriteHandle",
"WriteFileFn", "WriteFileFn",
"ProcessFn", "ProcessFn",
"StartupFn",
] ]

View File

@@ -49,6 +49,12 @@ type ProcessFn = dyn Fn(u64, Arc<ReadFileFn>, Arc<WriteFileFn>) -> Result<Vec<(u
+ Sync + Sync
+ 'static; + 'static;
/// Type-erased startup callback held by `Microservice`: it receives the shared
/// `read_file` and `write_file` functions and returns `(u64, CaseKey)` pairs,
/// or an `AnyError` on failure.
type StartupFn =
dyn Fn(Arc<ReadFileFn>, Arc<WriteFileFn>) -> Result<Vec<(u64, CaseKey)>, AnyError>
+ Send
+ Sync
+ 'static;
static REQUEST_FILE_CONTEXT_COUNTER: AtomicU64 = AtomicU64::new(1); static REQUEST_FILE_CONTEXT_COUNTER: AtomicU64 = AtomicU64::new(1);
static DATABASE_URL_CACHE: OnceLock<String> = OnceLock::new(); static DATABASE_URL_CACHE: OnceLock<String> = OnceLock::new();
static DATABASE_POOL_CACHE: OnceLock<Pool<ConnectionManager<PgConnection>>> = OnceLock::new(); static DATABASE_POOL_CACHE: OnceLock<Pool<ConnectionManager<PgConnection>>> = OnceLock::new();
@@ -90,6 +96,7 @@ enum CaseKey {
pub struct Microservice { pub struct Microservice {
name: String, name: String,
config_host: String, config_host: String,
startup: Arc<StartupFn>,
process: Arc<ProcessFn>, process: Arc<ProcessFn>,
} }
@@ -98,32 +105,60 @@ impl Microservice {
fn new_case_key( fn new_case_key(
name: impl Into<String>, name: impl Into<String>,
config_host: impl Into<String>, config_host: impl Into<String>,
startup: Arc<StartupFn>,
process: Arc<ProcessFn>, process: Arc<ProcessFn>,
) -> Self { ) -> Self {
init_tracing(); init_tracing();
Self { Self {
name: name.into(), name: name.into(),
config_host: config_host.into(), config_host: config_host.into(),
startup,
process, process,
} }
} }
/// Create a new microservice runtime. /// Create a new microservice runtime.
/// ///
/// `process` accepts an inbound request ID, a `read_file` function, and a /// `startup` runs once when the service starts, before queue consumption,
/// `write_file` function, and a mutable PostgreSQL connection, and then /// and receives `read_file`, `write_file`, and a mutable PostgreSQL
/// returns a list of /// connection.
///
/// `process` accepts an inbound request ID, `read_file`, `write_file`, and
/// a mutable PostgreSQL connection, and then returns a list of
/// `(result_id, case_variable)` tuples. Case variables can be any /// `(result_id, case_variable)` tuples. Case variables can be any
/// serializable primitive, such as `String`, `bool`, or integers. /// serializable primitive, such as `String`, `bool`, or integers.
pub fn new<F, C>(name: impl Into<String>, config_host: impl Into<String>, process: F) -> Self pub fn new<F, S, C>(
name: impl Into<String>,
config_host: impl Into<String>,
startup: S,
process: F,
) -> Self
where where
F: Fn(u64, &ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result<Vec<(u64, C)>, AnyError> F: Fn(u64, &ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result<Vec<(u64, C)>, AnyError>
+ Send + Send
+ Sync + Sync
+ 'static, + 'static,
S: Fn(&ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result<Vec<(u64, C)>, AnyError>
+ Send
+ Sync
+ 'static,
C: Serialize + 'static, C: Serialize + 'static,
{ {
init_tracing(); init_tracing();
let startup_wrapper = move |
read_file: Arc<ReadFileFn>,
write_file: Arc<WriteFileFn>,
| -> Result<Vec<(u64, CaseKey)>, AnyError> {
let mut connection = establish_pg_connection().map_err(|e| {
format!(
"failed to checkout PostgreSQL connection from pool for startup callback: {}",
e
)
})?;
let outputs = startup(read_file.as_ref(), write_file.as_ref(), &mut connection)?;
map_outputs_with_case_key(outputs)
};
let process_wrapper = move | let process_wrapper = move |
request: u64, request: u64,
read_file: Arc<ReadFileFn>, read_file: Arc<ReadFileFn>,
@@ -136,19 +171,13 @@ impl Microservice {
) )
})?; })?;
let outputs = process(request, read_file.as_ref(), write_file.as_ref(), &mut connection)?; let outputs = process(request, read_file.as_ref(), write_file.as_ref(), &mut connection)?;
let mut mapped = Vec::with_capacity(outputs.len()); map_outputs_with_case_key(outputs)
for (id, case) in outputs {
let value = serde_json::to_value(case)
.map_err(|e| format!("case variable must be serializable to JSON: {}", e))?;
mapped.push((id, case_key_from_value(&value)?));
}
Ok(mapped)
}; };
Self { Self {
name: name.into(), name: name.into(),
config_host: config_host.into(), config_host: config_host.into(),
startup: Arc::new(startup_wrapper),
process: Arc::new(process_wrapper), process: Arc::new(process_wrapper),
} }
} }
@@ -190,6 +219,61 @@ impl Microservice {
declare_queues(&channel, &inbound_queue, &route_map).await?; declare_queues(&channel, &inbound_queue, &route_map).await?;
let startup_request_id = 0_u64;
let startup_context = Arc::new(Mutex::new(RequestFileContext::new(startup_request_id)?));
let startup_read_context = Arc::clone(&startup_context);
let startup_write_context = Arc::clone(&startup_context);
let startup_s3_read_client = Arc::clone(&s3_client);
let startup_read_bucket_cache = Arc::clone(&bucket_name_cache);
let startup_write_bucket_cache = Arc::clone(&bucket_name_cache);
let startup_config_host = self.config_host.clone();
let startup_microservice_name = self.name.clone();
let startup_read_config_host = startup_config_host.clone();
let startup_read_microservice_name = startup_microservice_name.clone();
let startup_read_file: Arc<ReadFileFn> = Arc::new(move |key: &str, id: u64| -> Result<ReadFile, AnyError> {
let bucket = resolve_bucket_name(
&startup_read_config_host,
&startup_read_microservice_name,
&startup_read_bucket_cache,
key,
)?;
let mut guard = startup_read_context
.lock()
.map_err(|e| format!("file context lock poisoned for startup read_file: {}", e))?;
guard.read_file(startup_s3_read_client.as_ref(), &bucket, id)
});
let startup_write_file: Arc<WriteFileFn> = Arc::new(move |key: &str, id: u64| -> Result<File, AnyError> {
let bucket = resolve_bucket_name(
&startup_config_host,
&startup_microservice_name,
&startup_write_bucket_cache,
key,
)?;
let mut guard = startup_write_context
.lock()
.map_err(|e| format!("file context lock poisoned for startup write_file: {}", e))?;
guard.write_file(&bucket, id)
});
let startup_outputs = (self.startup)(
Arc::clone(&startup_read_file),
Arc::clone(&startup_write_file),
)
.map_err(|e| format!("startup callback failed: {}", e))?;
{
let mut guard = startup_context
.lock()
.map_err(|e| format!("file context lock poisoned for startup finalize: {}", e))?;
guard
.finalize(s3_client.as_ref())
.map_err(|e| format!("startup finalize/upload failed: {}", e))?
}
publish_outputs(&channel, startup_outputs, &route_map)
.await
.map_err(|e| format!("startup publish failed: {}", e))?;
let mut consumer = channel let mut consumer = channel
.basic_consume( .basic_consume(
&inbound_queue, &inbound_queue,
@@ -700,6 +784,17 @@ fn case_key_from_value(case_value: &Value) -> Result<CaseKey, AnyError> {
} }
} }
/// Converts callback outputs into `(id, CaseKey)` pairs.
///
/// Each case variable is serialized with `serde_json::to_value` and the
/// resulting JSON value is turned into a `CaseKey` by `case_key_from_value`.
///
/// # Errors
///
/// Stops at the first output whose case variable cannot be serialized to JSON
/// or is rejected by `case_key_from_value`, and returns that error.
fn map_outputs_with_case_key<C: Serialize>(outputs: Vec<(u64, C)>) -> Result<Vec<(u64, CaseKey)>, AnyError> {
    outputs
        .into_iter()
        .map(|(result_id, case_variable)| -> Result<(u64, CaseKey), AnyError> {
            let json = serde_json::to_value(case_variable)
                .map_err(|e| format!("case variable must be serializable to JSON: {}", e))?;
            let key = case_key_from_value(&json)?;
            Ok((result_id, key))
        })
        .collect()
}
async fn declare_queues( async fn declare_queues(
channel: &Channel, channel: &Channel,
inbound_queue: &str, inbound_queue: &str,
@@ -955,6 +1050,7 @@ fn create_python_sqlalchemy_engine() -> Result<Py<PyAny>, AnyError> {
struct PyMicroservice { struct PyMicroservice {
name: String, name: String,
config_host: String, config_host: String,
startup: Py<PyAny>,
process: Py<PyAny>, process: Py<PyAny>,
} }
@@ -962,8 +1058,12 @@ struct PyMicroservice {
#[pymethods] #[pymethods]
impl PyMicroservice { impl PyMicroservice {
#[new] #[new]
fn new(name: String, config_host: String, process: Py<PyAny>) -> PyResult<Self> { fn new(name: String, config_host: String, startup: Py<PyAny>, process: Py<PyAny>) -> PyResult<Self> {
Python::with_gil(|py| { Python::with_gil(|py| {
if !startup.bind(py).is_callable() {
return Err(PyTypeError::new_err("startup must be callable"));
}
if !process.bind(py).is_callable() { if !process.bind(py).is_callable() {
return Err(PyTypeError::new_err("process must be callable")); return Err(PyTypeError::new_err("process must be callable"));
} }
@@ -971,19 +1071,26 @@ impl PyMicroservice {
Ok(Self { Ok(Self {
name, name,
config_host, config_host,
startup,
process, process,
}) })
}) })
} }
fn start(&self) -> PyResult<()> { fn start(&self) -> PyResult<()> {
let startup = Python::with_gil(|py| self.startup.clone_ref(py));
let process = Python::with_gil(|py| self.process.clone_ref(py)); let process = Python::with_gil(|py| self.process.clone_ref(py));
let engine = create_python_sqlalchemy_engine().map_err(any_error_to_py)?; let engine = create_python_sqlalchemy_engine().map_err(any_error_to_py)?;
let startup_engine = Python::with_gil(|py| engine.clone_ref(py));
let process_engine = Python::with_gil(|py| engine.clone_ref(py));
let microservice = Microservice::new_case_key( let microservice = Microservice::new_case_key(
self.name.clone(), self.name.clone(),
self.config_host.clone(), self.config_host.clone(),
Arc::new(move |read_file, write_file| {
run_python_startup(&startup, &startup_engine, read_file, write_file)
}),
Arc::new(move |request, read_file, write_file| { Arc::new(move |request, read_file, write_file| {
run_python_process(&process, &engine, request, read_file, write_file) run_python_process(&process, &process_engine, request, read_file, write_file)
}), }),
); );
@@ -991,6 +1098,57 @@ impl PyMicroservice {
} }
} }
/// Runs the Python `startup` callable once and collects the outputs it yields.
///
/// While holding the GIL, wraps `read_file` and `write_file` in `PyReadFileFn` /
/// `PyWriteFileFn` objects, opens a connection via `engine.connect()`, and calls
/// `startup(py_read, py_write, connection)` with positional arguments. The
/// return value is iterated, and every item must extract as an `(int, case)`
/// tuple whose case converts through `case_key_from_py_value`.
///
/// # Errors
///
/// Returns an error when a wrapper object cannot be built, the connection
/// cannot be opened, the callback raises, its return value is not iterable, an
/// item has the wrong shape or an invalid case variable, or closing the
/// connection fails.
#[cfg(feature = "python")]
fn run_python_startup(
startup: &Py<PyAny>,
engine: &Py<PyAny>,
read_file: Arc<ReadFileFn>,
write_file: Arc<WriteFileFn>,
) -> Result<Vec<(u64, CaseKey)>, AnyError> {
Python::with_gil(|py| -> Result<Vec<(u64, CaseKey)>, AnyError> {
// Expose the Rust file functions to Python as callable wrapper objects.
let py_read = Py::new(py, PyReadFileFn { inner: read_file })
.map_err(|e| format!("failed to build Python ReadFileFn wrapper: {}", e))?;
let py_write = Py::new(py, PyWriteFileFn { inner: write_file })
.map_err(|e| format!("failed to build Python WriteFileFn wrapper: {}", e))?;
let connection = engine
.bind(py)
.call_method0("connect")
.map_err(|e| format!("failed to open SQLAlchemy connection: {}", e))?;
// Run the callback inside an immediately-invoked closure so that an early
// `?` return still falls through to the `close` call below.
let outputs_result = (|| -> Result<Vec<(u64, CaseKey)>, AnyError> {
let returned = startup
.call1(py, (py_read, py_write, connection.clone()))
.map_err(|e| format!("Python startup callback failed: {}", e))?;
let iter = returned
.bind(py)
.try_iter()
.map_err(|e| format!("startup return value must be iterable: {}", e))?;
// Drain the iterable here; each item is validated as it is produced.
let mut outputs = Vec::new();
for item in iter {
let item = item.map_err(|e| format!("failed to iterate startup outputs: {}", e))?;
let (id, case_obj): (u64, Py<PyAny>) = item
.extract()
.map_err(|e| format!("each startup output must be a tuple (int, case): {}", e))?;
let case = case_key_from_py_value(case_obj.bind(py))
.map_err(|e| format!("invalid case variable: {}", e))?;
outputs.push((id, case));
}
Ok(outputs)
})();
// Always close the connection, whether or not the callback succeeded.
let close_result = connection.call_method0("close");
// NOTE(review): a close failure is returned in place of `outputs_result`, so
// an earlier callback error is dropped when both fail — confirm this is intended.
if let Err(err) = close_result {
return Err(format!("failed to close SQLAlchemy connection: {}", err).into());
}
outputs_result
})
}
#[cfg(feature = "python")] #[cfg(feature = "python")]
#[pymodule] #[pymodule]
fn _native(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { fn _native(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {