diff --git a/README.md b/README.md index 0407d71..fdf69cd 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ assumptions about a microservice: 1. A microservice listens to incoming requests on its own dedicated and singular queue (RabbitMQ). 2. Incoming requests are in the form of a 64-bit unsigned integer (`u64`). +2. Microservices can run optional startup work via a `startup` function, which + takes three arguments: a `read_file` function, a `write_file` function, and + a database ORM `connection`. 2. Microservices process requests via a `process` function, which takes four arguments: the incoming request (`u64`), a `read_file` function, a `write_file` function, and a database ORM `connection`. @@ -85,6 +88,15 @@ from slingshot_microservice.typing import ReadFileFn, WriteFileFn from slingshot_microservice import Microservice +def startup( + read_file: ReadFileFn, + write_file: WriteFileFn, + connection: Connection, +) -> Generator[tuple[int, bool | int | str], None, None]: + if False: + yield (0, True) + + def process( request: int, read_file: ReadFileFn, @@ -100,7 +112,12 @@ def process( yield (request, True) -microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", process) +microservice = Microservice( + "simple-py-microservice", + "sys-map.slingshot.cv", + startup, + process, +) microservice.start() ``` @@ -113,7 +130,8 @@ editors and type-checkers: |---|---| | `ReadFileFn` | Callable returned by `read_file(key, id)` – behaves like `BinaryIO` | | `WriteFileFn` | Callable returned by `write_file(key, id)` – behaves like `BinaryIO` | -| `ProcessFn` | The generator signature expected by `Microservice` with `(request, read_file, write_file, connection)` | +| `StartupFn` | The generator signature expected by `Microservice` startup hook with `(read_file, write_file, connection)` | +| `ProcessFn` | The generator signature expected by `Microservice` process callback with `(request, read_file, write_file, connection)` | | `CaseVariable` | 
`bool \| int \| str` – valid case variable types | ### Publishing Wheels @@ -154,11 +172,20 @@ fn process( Ok(vec![(request, "case_a".to_string())]) } +fn startup( + read_file: &ReadFileFn, + write_file: &WriteFileFn, + connection: &mut PgConnection, +) -> Result, AnyError> { + Ok(Vec::new()) +} + fn main() { - // Create a new microservice instance with the processing function + // Create a new microservice instance with startup and process functions let microservice = Microservice::new( "simple-microservice", "sys-map.example.com", + startup, process ); @@ -220,7 +247,11 @@ actual secrets with `pass show ` before constructing the S3 client. When the microservice first starts up, it makes a request to the configuration service to get the queue metadata. Then it starts to listen to the inbound -queue. Inbound requests are processed by the user-programmed `process` +queue. Before consuming inbound requests, the user-programmed `startup` +function runs once with `(read_file, write_file, connection)` and may return +zero or more `(result_id, case_variable)` tuples for normal outbound routing. + +Inbound requests are then processed by the user-programmed `process` function, which is called with `(request, read_file, write_file, connection)` and returns a set of tuples of the form `(result_id, case_variable)`. 
@@ -247,6 +278,10 @@ The output queue routing step looks like this: -Peudocode: +Pseudocode: ``` +for each (result_id, case_variable) in startup(read_file, write_file, connection): + for each outbound_queue in config.out[case_variable]: + send result_id to outbound_queue + for each (result_id, case_variable) in process(request, read_file, write_file, connection): for each outbound_queue in config.out[case_variable]: send result_id to outbound_queue diff --git a/examples/py_simple.py b/examples/py_simple.py index eef1bc6..b96e3d2 100644 --- a/examples/py_simple.py +++ b/examples/py_simple.py @@ -3,6 +3,15 @@ from slingshot_microservice import Microservice from typing import Generator from sqlalchemy.engine.base import Connection + +def startup( + read_file: ReadFileFn, + write_file: WriteFileFn, + _connection: Connection, +) -> Generator[tuple[int, bool | int | str], None, None]: + if False: + yield (0, True) + def process( request: int, read_file: ReadFileFn, @@ -16,5 +25,5 @@ def process( yield (request, True) -microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", process) +microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", startup, process) microservice.start() diff --git a/examples/simple.rs b/examples/simple.rs index 5de7b2c..80acefd 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -18,8 +18,21 @@ fn process( Ok(vec![(request, "case_a".to_string())]) } +fn startup( + _read_file: &ReadFileFn, + _write_file: &WriteFileFn, + _connection: &mut PgConnection, +) -> Result, AnyError> { + Ok(Vec::new()) +} + fn main() -> Result<(), Box> { - let microservice = Microservice::new("simple-microservice", "sys-map.slingshot.cv", process); + let microservice = Microservice::new( + "simple-microservice", + "sys-map.slingshot.cv", + startup, + process, + ); microservice.start()?; Ok(()) diff --git a/pyproject.toml b/pyproject.toml index 55781d7..01fdf20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend 
= "maturin" [project] name = "slingshot-microservice" -version = "0.1.5" +version = "0.1.6" description = "Opinionated Rust framework for queue-driven microservices" license = { text = "MIT" } requires-python = ">=3.8" diff --git a/slingshot_microservice/typing.py b/slingshot_microservice/typing.py index 2ff2b90..08e4c95 100644 --- a/slingshot_microservice/typing.py +++ b/slingshot_microservice/typing.py @@ -33,10 +33,20 @@ class ProcessFn(Protocol): ) -> Generator[tuple[int, CaseVariable], None, None]: ... +class StartupFn(Protocol): + def __call__( + self, + read_file: ReadFileFn, + write_file: WriteFileFn, + connection: SqlAlchemyConnection, + ) -> Generator[tuple[int, CaseVariable], None, None]: ... + + __all__ = [ "CaseVariable", "ReadFileFn", "WriteHandle", "WriteFileFn", "ProcessFn", + "StartupFn", ] diff --git a/src/lib.rs b/src/lib.rs index 8a334e1..3390b6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,6 +49,12 @@ type ProcessFn = dyn Fn(u64, Arc, Arc) -> Result, Arc) -> Result, AnyError> + + Send + + Sync + + 'static; + static REQUEST_FILE_CONTEXT_COUNTER: AtomicU64 = AtomicU64::new(1); static DATABASE_URL_CACHE: OnceLock = OnceLock::new(); static DATABASE_POOL_CACHE: OnceLock>> = OnceLock::new(); @@ -90,6 +96,7 @@ enum CaseKey { pub struct Microservice { name: String, config_host: String, + startup: Arc, process: Arc, } @@ -98,32 +105,60 @@ impl Microservice { fn new_case_key( name: impl Into, config_host: impl Into, + startup: Arc, process: Arc, ) -> Self { init_tracing(); Self { name: name.into(), config_host: config_host.into(), + startup, process, } } /// Create a new microservice runtime. /// - /// `process` accepts an inbound request ID, a `read_file` function, and a - /// `write_file` function, and a mutable PostgreSQL connection, and then - /// returns a list of + /// `startup` runs once when the service starts, before queue consumption, + /// and receives `read_file`, `write_file`, and a mutable PostgreSQL + /// connection. 
+ /// + /// `process` accepts an inbound request ID, `read_file`, `write_file`, and + /// a mutable PostgreSQL connection, and then returns a list of /// `(result_id, case_variable)` tuples. Case variables can be any /// serializable primitive, such as `String`, `bool`, or integers. - pub fn new(name: impl Into, config_host: impl Into, process: F) -> Self + pub fn new( + name: impl Into, + config_host: impl Into, + startup: S, + process: F, + ) -> Self where F: Fn(u64, &ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result, AnyError> + Send + Sync + 'static, + S: Fn(&ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result, AnyError> + + Send + + Sync + + 'static, C: Serialize + 'static, { init_tracing(); + let startup_wrapper = move | + read_file: Arc, + write_file: Arc, + | -> Result, AnyError> { + let mut connection = establish_pg_connection().map_err(|e| { + format!( + "failed to checkout PostgreSQL connection from pool for startup callback: {}", + e + ) + })?; + let outputs = startup(read_file.as_ref(), write_file.as_ref(), &mut connection)?; + map_outputs_with_case_key(outputs) + }; + let process_wrapper = move | request: u64, read_file: Arc, @@ -136,19 +171,13 @@ impl Microservice { ) })?; let outputs = process(request, read_file.as_ref(), write_file.as_ref(), &mut connection)?; - let mut mapped = Vec::with_capacity(outputs.len()); - for (id, case) in outputs { - let value = serde_json::to_value(case) - .map_err(|e| format!("case variable must be serializable to JSON: {}", e))?; - mapped.push((id, case_key_from_value(&value)?)); - } - - Ok(mapped) + map_outputs_with_case_key(outputs) }; Self { name: name.into(), config_host: config_host.into(), + startup: Arc::new(startup_wrapper), process: Arc::new(process_wrapper), } } @@ -190,6 +219,61 @@ impl Microservice { declare_queues(&channel, &inbound_queue, &route_map).await?; + let startup_request_id = 0_u64; + let startup_context = Arc::new(Mutex::new(RequestFileContext::new(startup_request_id)?)); + let 
startup_read_context = Arc::clone(&startup_context); + let startup_write_context = Arc::clone(&startup_context); + let startup_s3_read_client = Arc::clone(&s3_client); + let startup_read_bucket_cache = Arc::clone(&bucket_name_cache); + let startup_write_bucket_cache = Arc::clone(&bucket_name_cache); + let startup_config_host = self.config_host.clone(); + let startup_microservice_name = self.name.clone(); + let startup_read_config_host = startup_config_host.clone(); + let startup_read_microservice_name = startup_microservice_name.clone(); + + let startup_read_file: Arc = Arc::new(move |key: &str, id: u64| -> Result { + let bucket = resolve_bucket_name( + &startup_read_config_host, + &startup_read_microservice_name, + &startup_read_bucket_cache, + key, + )?; + let mut guard = startup_read_context + .lock() + .map_err(|e| format!("file context lock poisoned for startup read_file: {}", e))?; + guard.read_file(startup_s3_read_client.as_ref(), &bucket, id) + }); + + let startup_write_file: Arc = Arc::new(move |key: &str, id: u64| -> Result { + let bucket = resolve_bucket_name( + &startup_config_host, + &startup_microservice_name, + &startup_write_bucket_cache, + key, + )?; + let mut guard = startup_write_context + .lock() + .map_err(|e| format!("file context lock poisoned for startup write_file: {}", e))?; + guard.write_file(&bucket, id) + }); + + let startup_outputs = (self.startup)( + Arc::clone(&startup_read_file), + Arc::clone(&startup_write_file), + ) + .map_err(|e| format!("startup callback failed: {}", e))?; + { + let mut guard = startup_context + .lock() + .map_err(|e| format!("file context lock poisoned for startup finalize: {}", e))?; + guard + .finalize(s3_client.as_ref()) + .map_err(|e| format!("startup finalize/upload failed: {}", e))? 
+ } + publish_outputs(&channel, startup_outputs, &route_map) + .await + .map_err(|e| format!("startup publish failed: {}", e))?; + let mut consumer = channel .basic_consume( &inbound_queue, @@ -700,6 +784,17 @@ fn case_key_from_value(case_value: &Value) -> Result { } } +fn map_outputs_with_case_key(outputs: Vec<(u64, C)>) -> Result, AnyError> { + let mut mapped = Vec::with_capacity(outputs.len()); + for (id, case) in outputs { + let value = serde_json::to_value(case) + .map_err(|e| format!("case variable must be serializable to JSON: {}", e))?; + mapped.push((id, case_key_from_value(&value)?)); + } + + Ok(mapped) +} + async fn declare_queues( channel: &Channel, inbound_queue: &str, @@ -955,6 +1050,7 @@ fn create_python_sqlalchemy_engine() -> Result, AnyError> { struct PyMicroservice { name: String, config_host: String, + startup: Py, process: Py, } @@ -962,8 +1058,12 @@ struct PyMicroservice { #[pymethods] impl PyMicroservice { #[new] - fn new(name: String, config_host: String, process: Py) -> PyResult { + fn new(name: String, config_host: String, startup: Py, process: Py) -> PyResult { Python::with_gil(|py| { + if !startup.bind(py).is_callable() { + return Err(PyTypeError::new_err("startup must be callable")); + } + if !process.bind(py).is_callable() { return Err(PyTypeError::new_err("process must be callable")); } @@ -971,19 +1071,26 @@ impl PyMicroservice { Ok(Self { name, config_host, + startup, process, }) }) } fn start(&self) -> PyResult<()> { + let startup = Python::with_gil(|py| self.startup.clone_ref(py)); let process = Python::with_gil(|py| self.process.clone_ref(py)); let engine = create_python_sqlalchemy_engine().map_err(any_error_to_py)?; + let startup_engine = Python::with_gil(|py| engine.clone_ref(py)); + let process_engine = Python::with_gil(|py| engine.clone_ref(py)); let microservice = Microservice::new_case_key( self.name.clone(), self.config_host.clone(), + Arc::new(move |read_file, write_file| { + run_python_startup(&startup, &startup_engine, 
read_file, write_file) + }), Arc::new(move |request, read_file, write_file| { - run_python_process(&process, &engine, request, read_file, write_file) + run_python_process(&process, &process_engine, request, read_file, write_file) }), ); @@ -991,6 +1098,57 @@ impl PyMicroservice { } } + +#[cfg(feature = "python")] +fn run_python_startup( + startup: &Py, + engine: &Py, + read_file: Arc, + write_file: Arc, +) -> Result, AnyError> { + Python::with_gil(|py| -> Result, AnyError> { + let py_read = Py::new(py, PyReadFileFn { inner: read_file }) + .map_err(|e| format!("failed to build Python ReadFileFn wrapper: {}", e))?; + let py_write = Py::new(py, PyWriteFileFn { inner: write_file }) + .map_err(|e| format!("failed to build Python WriteFileFn wrapper: {}", e))?; + let connection = engine + .bind(py) + .call_method0("connect") + .map_err(|e| format!("failed to open SQLAlchemy connection: {}", e))?; + + let outputs_result = (|| -> Result, AnyError> { + let returned = startup + .call1(py, (py_read, py_write, connection.clone())) + .map_err(|e| format!("Python startup callback failed: {}", e))?; + + let iter = returned + .bind(py) + .try_iter() + .map_err(|e| format!("startup return value must be iterable: {}", e))?; + + let mut outputs = Vec::new(); + for item in iter { + let item = item.map_err(|e| format!("failed to iterate startup outputs: {}", e))?; + let (id, case_obj): (u64, Py) = item + .extract() + .map_err(|e| format!("each startup output must be a tuple (int, case): {}", e))?; + let case = case_key_from_py_value(case_obj.bind(py)) + .map_err(|e| format!("invalid case variable: {}", e))?; + outputs.push((id, case)); + } + + Ok(outputs) + })(); + + let close_result = connection.call_method0("close"); + if let Err(err) = close_result { + return Err(format!("failed to close SQLAlchemy connection: {}", err).into()); + } + + outputs_result + }) +} + #[cfg(feature = "python")] #[pymodule] fn _native(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {