import io
import zipfile
from datetime import datetime, timezone
from typing import Optional, Sequence

from app import models


class ShorewallGenerator:
    """Render a ``models.Config`` as the text of Shorewall configuration files.

    Each file method (``zones``, ``interfaces``, ``policy``, ``rules``,
    ``hosts``, ``params``, ``snat``) returns the complete contents of the file
    with that name; ``as_json`` and ``as_zip`` bundle all of them.
    """

    # Per-column display widths, shared by each file's header line and its data
    # rows so the two line up. Columns past the end of a tuple use _col's
    # default width. Widths are purely cosmetic: Shorewall separates columns by
    # whitespace, and _col always emits at least one space between fields.
    _POLICY_WIDTHS = (16, 16, 16, 16, 20)
    _RULES_WIDTHS = (16, 24, 24, 10, 16, 16, 20, 16, 16, 12, 14, 20, 16, 16)
    _SNAT_WIDTHS = (24, 24, 20, 10, 16, 16, 12, 16, 16, 20)

    def __init__(self, config: models.Config) -> None:
        self._config = config
        # Captured once so every file rendered by this instance carries the same timestamp.
        self._ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    def _header(self, filename: str) -> str:
        """Return the comment line identifying the file, config name and render time."""
        return (
            f"# {filename} — generated by shorefront "
            f"| config: {self._config.name} "
            f"| {self._ts}\n"
        )

    @staticmethod
    def _cell(value: object) -> str:
        """Return the text for one column, using ``-`` when the value is missing.

        A blank column would vanish into the surrounding whitespace and shift
        every later column left, so ``None`` and ``""`` become Shorewall's
        ``-`` placeholder. Non-string values are converted with ``str()``.
        """
        if value is None or value == "":
            return "-"
        return value if isinstance(value, str) else str(value)

    def _col(
        self,
        *values: object,
        width: int = 16,
        widths: Optional[Sequence[int]] = None,
    ) -> str:
        """Join *values* into one padded, newline-terminated line.

        Args:
            values: Column values in file order.
            width: Padding width for every column not covered by *widths*.
            widths: Optional per-column widths; ``widths[i]`` applies to ``values[i]``.

        Returns:
            The joined line with trailing padding stripped and a newline appended.
        """
        cells = []
        for index, value in enumerate(values):
            if widths is not None and index < len(widths):
                col_width = widths[index]
            else:
                col_width = width
            # ljust() adds no padding once the text reaches the column width, so
            # a long value (e.g. "net:192.168.100.100/32") used to be glued onto
            # the next column. Padding to width-1 plus a literal space always
            # leaves a separator while keeping short values at exactly `col_width`.
            cells.append(self._cell(value).ljust(col_width - 1) + " ")
        return "".join(cells).rstrip() + "\n"

    def zones(self) -> str:
        """Return the ``zones`` file: ZONE, TYPE, OPTIONS for each configured zone."""
        lines = [self._header("zones"), self._col("#ZONE", "TYPE", "OPTIONS")]
        for z in self._config.zones:
            lines.append(self._col(z.name, z.type, z.options or "-"))
        return "".join(lines)

    def interfaces(self) -> str:
        """Return the ``interfaces`` file; an interface with no zone gets ``-``."""
        lines = [
            self._header("interfaces"),
            self._col("#ZONE", "INTERFACE", "BROADCAST", "OPTIONS"),
        ]
        for iface in self._config.interfaces:
            zone = iface.zone.name if iface.zone else "-"
            lines.append(self._col(zone, iface.name, iface.broadcast or "-", iface.options or "-"))
        return "".join(lines)

    def policy(self) -> str:
        """Return the ``policy`` file, ordered by each policy's ``position``.

        A missing source or destination zone is written as ``all``.
        """
        lines = [
            self._header("policy"),
            self._col(
                "#SOURCE", "DEST", "POLICY", "LOG LEVEL", "LIMIT:BURST", "CONNLIMIT:MASK",
                widths=self._POLICY_WIDTHS,
            ),
        ]
        for p in sorted(self._config.policies, key=lambda x: x.position):
            src = p.src_zone.name if p.src_zone else "all"
            dst = p.dst_zone.name if p.dst_zone else "all"
            lines.append(self._col(
                src,
                dst,
                p.policy,
                p.log_level or "-",
                p.limit_burst or "-",
                p.connlimit_mask or "-",
                widths=self._POLICY_WIDTHS,
            ))
        return "".join(lines)

    def rules(self) -> str:
        """Return the ``rules`` file, ordered by each rule's ``position``.

        SOURCE/DEST are ``zone[:ip]``; a missing zone is written as ``all``.
        """
        lines = [
            self._header("rules"),
            self._col(
                "#ACTION", "SOURCE", "DEST", "PROTO", "DPORT", "SPORT", "ORIGDEST",
                "RATE", "USER", "MARK", "CONNLIMIT", "TIME", "HEADERS", "SWITCH", "HELPER",
                widths=self._RULES_WIDTHS,
            ),
            # NOTE(review): newer Shorewall releases spell this "?SECTION NEW";
            # confirm which form the target Shorewall version expects.
            "SECTION NEW\n",
        ]
        for r in sorted(self._config.rules, key=lambda x: x.position):
            src = (r.src_zone.name if r.src_zone else "all") + (f":{r.src_ip}" if r.src_ip else "")
            dst = (r.dst_zone.name if r.dst_zone else "all") + (f":{r.dst_ip}" if r.dst_ip else "")
            lines.append(self._col(
                r.action,
                src,
                dst,
                r.proto or "-",
                r.dport or "-",
                r.sport or "-",
                r.origdest or "-",
                r.rate_limit or "-",
                r.user_group or "-",
                r.mark or "-",
                r.connlimit or "-",
                r.time or "-",
                r.headers or "-",
                r.switch_name or "-",
                r.helper or "-",
                widths=self._RULES_WIDTHS,
            ))
        return "".join(lines)

    def hosts(self) -> str:
        """Return the ``hosts`` file: ZONE, ``interface:subnet``, OPTIONS per entry."""
        # The header now names the OPTIONS column that every row emits.
        lines = [self._header("hosts"), self._col("#ZONE", "HOSTS", "OPTIONS")]
        for h in self._config.host_entries:
            hosts_val = f"{h.interface}:{h.subnet}"
            lines.append(self._col(h.zone.name, hosts_val, h.options or "-"))
        return "".join(lines)

    def params(self) -> str:
        """Return the ``params`` file as ``NAME=value`` lines, written verbatim."""
        lines = [self._header("params")]
        # NOTE(review): values are emitted unquoted; confirm they never need
        # shell quoting (e.g. contain spaces) before reaching this point.
        for p in self._config.params:
            lines.append(f"{p.name}={p.value}\n")
        return "".join(lines)

    def snat(self) -> str:
        """Return the ``snat`` file.

        ACTION is ``SNAT:<to_address>`` when a target address is set, otherwise
        ``MASQUERADE``.
        """
        lines = [
            self._header("snat"),
            self._col(
                "#ACTION", "SOURCE", "DEST", "PROTO", "PORT", "IPSEC", "MARK",
                "USER/GROUP", "SWITCH", "ORIGDEST", "PROBABILITY",
                widths=self._SNAT_WIDTHS,
            ),
        ]
        for m in self._config.snat_entries:
            action = f"SNAT:{m.to_address}" if m.to_address else "MASQUERADE"
            lines.append(self._col(
                action,
                m.source_network,
                m.out_interface,
                m.proto or "-",
                m.port or "-",
                m.ipsec or "-",
                m.mark or "-",
                m.user_group or "-",
                m.switch_name or "-",
                m.origdest or "-",
                m.probability or "-",
                widths=self._SNAT_WIDTHS,
            ))
        return "".join(lines)

    def as_json(self) -> dict:
        """Return a dict mapping each Shorewall file name to its rendered contents.

        Presumably serialized to JSON by the caller (the dict itself is plain
        ``str -> str``).
        """
        return {
            "zones": self.zones(),
            "interfaces": self.interfaces(),
            "policy": self.policy(),
            "rules": self.rules(),
            "snat": self.snat(),
            "hosts": self.hosts(),
            "params": self.params(),
        }

    def as_zip(self) -> bytes:
        """Return an in-memory DEFLATE ZIP archive with one entry per Shorewall file.

        Entries are named after the files and appear in the same order as in
        ``as_json``; reusing it keeps the two bundles from drifting apart.
        """
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
            for name, content in self.as_json().items():
                zf.writestr(name, content)
        return buf.getvalue()