aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
blob: 09515183a04fc484826e0787d41d5f3ecb853111 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use std::convert::Infallible;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::rc::Rc;
use std::sync::Arc;

use hyper::service::{make_service_fn, service_fn};
use hyper::{header, Body, Request, Response, Server, StatusCode};
use rhai::{Dynamic, Engine, FnPtr, Map, NativeCallContext, Scope};
use structopt::StructOpt;

mod utils;

// Command-line options for the daemon; `main` obtains them with
// `Opt::from_args()`.
//
// Plain `//` comments are used on purpose here: structopt turns `///` doc
// comments on this type and its fields into `--help` output.
#[derive(Debug, StructOpt)]
struct Opt {
    // Path of the Rhai configuration script that `get_config_scope` compiles
    // and evaluates. Given as a long flag (presumably `--config-script` under
    // structopt's default renaming — confirm); defaults to `narchttpd.rhai`.
    #[structopt(long, parse(from_os_str), default_value = "narchttpd.rhai")]
    config_script: PathBuf,
}

fn make_engine() -> Engine {
    let mut engine = Engine::new();
    engine.register_type_with_name::<Rc<Request<Body>>>("Request");
    engine.register_type::<utils::serve_static::Params>();
    engine.register_fn(
        "handle_request_serve_static",
        utils::serve_static::handle_request,
    );
    engine.register_fn("serve_static", utils::serve_static::serve_static);
    engine.register_fn(
        "handle_request_proxy_child",
        utils::proxy_child::handle_request,
    );
    engine.register_fn("proxy_child", utils::proxy_child::proxy_child);
    engine
}

fn get_config_scope(engine: &Engine, opt: &Opt) -> Scope<'static> {
    let mut ast = engine.compile_file(opt.config_script.clone()).unwrap();

    let mut scope = Scope::new();
    scope.push("http_ports", [80]);
    scope.push("https_ports", [443]);
    scope.push("domains", Map::new());
    let export_ast = engine
        .compile("export http_ports, https_ports, domains;")
        .unwrap();
    ast.combine(export_ast);
    let _: () = engine.eval_ast_with_scope(&mut scope, &ast).unwrap();
    scope
}

#[tokio::main]
async fn main() {
    let opt = Arc::new(Opt::from_args());

    let engine = make_engine();
    let scope = get_config_scope(&engine, &opt);

    let http_ports: rhai::Array = scope.get_value("http_ports").unwrap();
    let http_ports: Vec<u16> = http_ports
        .into_iter()
        .map(|x| x.as_int().unwrap() as u16)
        .collect();
    let https_ports: rhai::Array = scope.get_value("https_ports").unwrap();
    let https_ports: Vec<u16> = https_ports
        .into_iter()
        .map(|x| x.as_int().unwrap() as u16)
        .collect();

    assert!(https_ports.is_empty(), "HTTPS is complicated oops");

    // TODO learn hyper
    let do_response = move |ctx: &NativeCallContext, domains: Map, req: Request<Body>| {
        for (domain, handler) in &domains {
            let req_domain = req.headers().get(header::HOST).unwrap();
            if domain.as_str() == req_domain {
                eprintln!("request {:?} matched domain {}", req, domain);
                // matched!
                let handler: FnPtr = handler.clone_cast();
                let args = [Dynamic::from(Rc::new(req))];
                let result = handler.call_dynamic(ctx, None, args).unwrap();
                let result: Rc<Response<Body>> = result.cast();
                return Rc::try_unwrap(result).unwrap();
            }
        }
        Response::builder()
            .status(StatusCode::NOT_FOUND)
            .body(Body::from("Not Found"))
            .unwrap()
    };
    let addr = http_ports.into_iter().flat_map(|port| {
        ["127.0.0.1", "::1"]
            .iter()
            .map(move |ip| SocketAddr::new(ip.parse().unwrap(), port))
    });

    let make_svc = make_service_fn(move |_conn| async move {
        Ok::<_, Infallible>(service_fn(move |req| async move {
            let handle = tokio::runtime::Handle::current();
            let response = handle
                .spawn_blocking(move || {
                    let opt = Opt::from_args();
                    let engine = make_engine();
                    let scope = get_config_scope(&engine, &opt);
                    let request_handler_context =
                        NativeCallContext::new(&engine, "handle_request", &[]);
                    let domains: Map = scope.get_value("domains").unwrap();
                    do_response(&request_handler_context, domains, req)
                })
                .await
                .unwrap();
            Ok::<_, Infallible>(response)
        }))
    });

    // TODO uhh
    let mut addr = addr;
    let addr = addr.nth(0).unwrap();
    let server = Server::bind(&addr).serve(make_svc);

    if let Err(e) = server.await {
        eprintln!("server error: {}", e);
    }
}