use std::convert::Infallible; use std::net::SocketAddr; use std::path::PathBuf; use std::rc::Rc; use std::sync::Arc; use hyper::service::{make_service_fn, service_fn}; use hyper::{header, Body, Request, Response, Server, StatusCode}; use rhai::{Dynamic, Engine, FnPtr, Map, NativeCallContext, Scope}; use structopt::StructOpt; mod utils; #[derive(Debug, StructOpt)] struct Opt { #[structopt(long, parse(from_os_str), default_value = "narchttpd.rhai")] config_script: PathBuf, } fn make_engine() -> Engine { let mut engine = Engine::new(); engine.register_type_with_name::>>("Request"); engine.register_type::(); engine.register_fn( "handle_request_serve_static", utils::serve_static::handle_request, ); engine.register_fn("serve_static", utils::serve_static::serve_static); engine.register_fn( "handle_request_proxy_child", utils::proxy_child::handle_request, ); engine.register_fn("proxy_child", utils::proxy_child::proxy_child); engine } fn get_config_scope(engine: &Engine, opt: &Opt) -> Scope<'static> { let mut ast = engine.compile_file(opt.config_script.clone()).unwrap(); let mut scope = Scope::new(); scope.push("http_ports", [80]); scope.push("https_ports", [443]); scope.push("domains", Map::new()); let export_ast = engine .compile("export http_ports, https_ports, domains;") .unwrap(); ast.combine(export_ast); let _: () = engine.eval_ast_with_scope(&mut scope, &ast).unwrap(); scope } #[tokio::main] async fn main() { let opt = Arc::new(Opt::from_args()); let engine = make_engine(); let scope = get_config_scope(&engine, &opt); let http_ports: rhai::Array = scope.get_value("http_ports").unwrap(); let http_ports: Vec = http_ports .into_iter() .map(|x| x.as_int().unwrap() as u16) .collect(); let https_ports: rhai::Array = scope.get_value("https_ports").unwrap(); let https_ports: Vec = https_ports .into_iter() .map(|x| x.as_int().unwrap() as u16) .collect(); assert!(https_ports.is_empty(), "HTTPS is complicated oops"); // TODO learn hyper let do_response = move |ctx: &NativeCallContext, domains: Map, req: Request| { for (domain, handler) in &domains { let req_domain = req.headers().get(header::HOST).unwrap(); if domain.as_str() == req_domain { eprintln!("request {:?} matched domain {}", req, domain); // matched! let handler: FnPtr = handler.clone_cast(); let args = [Dynamic::from(Rc::new(req))]; let result = handler.call_dynamic(ctx, None, args).unwrap(); let result: Rc> = result.cast(); return Rc::try_unwrap(result).unwrap(); } } Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not Found")) .unwrap() }; let addr = http_ports.into_iter().flat_map(|port| { ["127.0.0.1", "::1"] .iter() .map(move |ip| SocketAddr::new(ip.parse().unwrap(), port)) }); let make_svc = make_service_fn(move |_conn| async move { Ok::<_, Infallible>(service_fn(move |req| async move { let handle = tokio::runtime::Handle::current(); let response = handle .spawn_blocking(move || { let opt = Opt::from_args(); let engine = make_engine(); let scope = get_config_scope(&engine, &opt); let request_handler_context = NativeCallContext::new(&engine, "handle_request", &[]); let domains: Map = scope.get_value("domains").unwrap(); do_response(&request_handler_context, domains, req) }) .await .unwrap(); Ok::<_, Infallible>(response) })) }); // TODO uhh let mut addr = addr; let addr = addr.nth(0).unwrap(); let server = Server::bind(&addr).serve(make_svc); if let Err(e) = server.await { eprintln!("server error: {}", e); } }