diff --git a/.gitignore b/.gitignore index da7fbc2..d0a5a31 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .direnv __pycache__/ +devices.db diff --git a/app.py b/app.py new file mode 100644 index 0000000..5edce19 --- /dev/null +++ b/app.py @@ -0,0 +1,260 @@ +from __future__ import annotations +from contextlib import asynccontextmanager +from typing import Annotated +from fastapi import Depends, FastAPI, HTTPException +from sqlmodel import SQLModel, Session, create_engine, delete, select + +import netbrite as nb +from db import ( + MessageDB, + NetBriteBase, + NetBriteDB, + NetBritePublic, + ZoneBase, + ZoneDB, + ZonePublic, +) + +DB_URL = "sqlite:///devices.db" +engine = create_engine(DB_URL, connect_args={"check_same_thread": False}) + + +def get_session(): + with Session(engine) as session: + yield session + + +SessionDep = Annotated[Session, Depends(get_session)] + + +@asynccontextmanager +async def lifespan(_: FastAPI): + SQLModel.metadata.create_all(engine) + load_devices_from_db() + yield + + +app = FastAPI(lifespan=lifespan) + +active_devices: dict[int, nb.NetBrite] = {} +device_status: dict[int, bool] = {} + + +# ---------- helper ---------- +def load_devices_from_db() -> None: + with Session(engine) as session: + for device in session.exec(select(NetBriteDB)).all(): + load_device(device, session) + + +def load_device(device: NetBriteDB, session: SessionDep): + id = device.id or 0 + try: + active_devices[id] = nb.NetBrite(device.address, device.port) + device_status[id] = True + load_zones_id(session, id, active_devices[id]) + except nb.NetbriteConnectionException as exc: + device_status[id] = False + print(f"Could not connect to {device.address}:{device.port} — {exc}") + + +def load_zones_id(session: Session, device_id: int, net_dev: nb.NetBrite): + statement = select(ZoneDB).where(ZoneDB.netbrite_id == device_id) + zones = list(session.exec(statement)) + load_zones(zones, net_dev) + + +def load_zones(zones_in: list[ZoneDB], net_dev: nb.NetBrite) -> None: + zones: dict[str, nb.Zone] = {} + + for zone in zones_in: + msg = zone.default_message + + default_msg = ( + nb.Message( + activation_delay=msg.activation_delay, + display_delay=msg.display_delay, + priority=msg.priority, + text=msg.text, + ttl=msg.ttl, + ) + if msg + else nb.Message(f"Zone {zone.name}") + ) + + zones[zone.name] = nb.Zone( + x=zone.x, + y=zone.y, + width=zone.width, + height=zone.height, + scroll_speed=zone.scroll_speed, + pause_duration=zone.pause_duration, + volume=zone.volume, + default_font=zone.default_font, + default_color=zone.default_color, + initial_text=default_msg, + ) + + if zones: + net_dev.zones(zones) + + +def create_default_zone(session: Session, device_id: int) -> None: + zone = ZoneDB( + name="0", + x=0, + y=0, + width=120, + height=7, + netbrite_id=device_id, + ) + msg = MessageDB( + text="{erase}Welcome", + ) + session.add(msg) + session.add(zone) + + session.commit() + session.refresh(zone) + session.refresh(msg) + + zone.default_message_id = msg.id + session.commit() + + +# ---------- routes ---------- +@app.post("/api/device", response_model=NetBritePublic) +def create_device(device: NetBriteBase, session: SessionDep): + if session.exec( + select(NetBriteDB).where(NetBriteDB.address == device.address) + ).first(): + raise HTTPException(400, "Device already exists") + + db_device = NetBriteDB.model_validate(device) + session.add(db_device) + session.commit() + session.refresh(db_device) + + load_device(db_device, session) + return db_device + + +@app.get("/api/devices", response_model=list[NetBritePublic]) +def get_devices(session: SessionDep): + devices: list[NetBritePublic] = [] + + for device in session.exec(select(NetBriteDB)).all(): + device = NetBritePublic.model_validate( + device, update={"active": device_status.get(device.id or 0) or False} + ) + devices.append(device) + + return devices # FIXME: implement active + + +@app.post("/api/devices/{device_id}", response_model=NetBritePublic) +def update_device(device_id: int, updated_device: NetBriteBase, session: SessionDep): + db_dev = session.get(NetBriteDB, device_id) + if not db_dev: + raise HTTPException(404, "Device not found") + + db_dev.port = updated_device.port + db_dev.address = updated_device.address + + return 200 + + +# TODO: implement me +@app.post("/api/devices/{device_id}/reconnect") +def reconnect_device(device_id: int, session: SessionDep): + db_dev = session.get(NetBriteDB, device_id) + if not db_dev: + raise HTTPException(404, "Device not found") + + try: + active_devices[device_id] = nb.NetBrite(db_dev.address, db_dev.port) + device_status[device_id] = True + load_zones_id(session, device_id, active_devices[device_id]) + return 200 + except nb.NetbriteConnectionException as exc: + device_status[device_id] = False + raise HTTPException(400, str(exc)) + + +@app.delete("/api/devices/{device_id}") +def delete_device(device_id: int, session: SessionDep): + db_dev = session.get(NetBriteDB, device_id) + if not db_dev: + raise HTTPException(404, "Device not found") + + delete: list[MessageDB | ZoneDB | NetBriteDB] = [db_dev] + for zone in db_dev.zones: + if zone.default_message != None: + delete.append(zone.default_message) + delete.append(zone) + + if device_id in active_devices: + active_devices[device_id].sock.close() + del active_devices[device_id] + + for i in delete: + session.delete(i) + + session.commit() + return 200 + + +@app.post("/api/devices/{device_id}/zones", response_model=ZonePublic) +def create_zone(device_id: int, body: ZoneBase, session: SessionDep): + device = session.get(NetBriteDB, device_id) + if not device: + raise HTTPException(404, "Device not found") + + zone = ZoneDB.model_validate(body) + msg = MessageDB( + text="{erase}Welcome", + ) + session.add(zone) + session.add(msg) + + session.commit() + session.refresh(zone) + session.refresh(msg) + + zone.default_message_id = msg.id + session.commit() + + if not device.id in active_devices: + raise HTTPException(503, "Device not active") + + try: + load_zones(device.zones, active_devices[device.id]) + except nb.NetbriteTransferException: + raise HTTPException(503, "Device not active") + + +@app.get("/api/devices/{device_id}/zones", response_model=list[ZonePublic]) +def get_zones(device_id: int, session: SessionDep): + device = session.get(NetBriteDB, device_id) + if not device: + raise HTTPException(404, "Device not found") + + return device.zones + + +@app.delete( + "/api/zone/{zone_id}", +) +def delete_zone(zone_id: int, session: SessionDep): + zone = session.get(ZoneDB, zone_id) + if not zone: + raise HTTPException(404, "Zone not found") + + message = zone.default_message + if message: + session.delete(message) + + session.delete(zone) + session.commit() + + return 200 diff --git a/db.py b/db.py new file mode 100644 index 0000000..a62f3f6 --- /dev/null +++ b/db.py @@ -0,0 +1,77 @@ +# from __future__ import annotations +from sqlmodel import Field, Relationship, SQLModel +from netbrite import Colors, Fonts, Priorities, ScrollSpeeds, Message + + +# --- Message --- +class MessageBase(SQLModel): + activation_delay: int = 0 + display_delay: int = 0 + display_repeat: int = 0 + priority: Priorities = Priorities.OVERRIDE + text: str = "" + ttl: int = 0 + + +class MessageDB(MessageBase, table=True): + id: int | None = Field(default=None, primary_key=True) + zone: "ZoneDB" = Relationship( # pyright: ignore[reportAny] + back_populates="default_message" + ) + + +class MessagePublic(MessageBase): + id: int + + +# --- Device --- +class NetBriteBase(SQLModel): + address: str = Field(unique=True, index=True) + port: int = 700 + + +class NetBriteDB(NetBriteBase, table=True): + id: int | None = Field(default=None, primary_key=True) + zones: list["ZoneDB"] = Relationship( # pyright: ignore[reportAny] + back_populates="netbrite" + ) + + +class NetBritePublic(NetBriteBase): + id: int + zones: list["ZoneDB"] + active: bool + + +# --- Zone --- +class ZoneBase(SQLModel): + name: str + x: int + y: int + width: int + height: int + scroll_speed: ScrollSpeeds = ScrollSpeeds.NORMAL + pause_duration: int = 1000 + volume: int = 4 + default_font: Fonts = Fonts.NORMAL_7 + default_color: Colors = Colors.RED + + default_message_id: int | None = Field(default=None, foreign_key="messagedb.id") + netbrite_id: int = Field(default=None, foreign_key="netbritedb.id") + + +class ZoneDB(ZoneBase, table=True): + id: int | None = Field(default=None, primary_key=True) + + default_message: MessageDB | None = Relationship( # pyright: ignore[reportAny] + back_populates="zone" + ) + netbrite: NetBriteDB = Relationship( # pyright: ignore[reportAny] + back_populates="zones" + ) + + +class ZonePublic(ZoneBase): + id: int + default_message: MessagePublic + # netbrite: NetBritePublic diff --git a/flake.nix b/flake.nix index 53bf9e0..ac1634a 100644 --- a/flake.nix +++ b/flake.nix @@ -15,7 +15,7 @@ nativeBuildInputs = [ pkgs.entr pkgs.fastapi-cli - (pkgs.python3.withPackages (x: [x.crc x.fastapi])) + (pkgs.python3.withPackages (x: [x.crc x.fastapi x.sqlmodel x.sqlalchemy])) ]; }; }; diff --git a/netbrite.py b/netbrite.py index ba97efb..1a39dbb 100644 --- a/netbrite.py +++ b/netbrite.py @@ -10,6 +10,14 @@ import re DEFAULT_PORT = 700 +class NetbriteConnectionException(Exception): + pass + + +class NetbriteTransferException(Exception): + pass + + class Colors(Enum): RED = 0x01 GREEN = 0x02 @@ -101,7 +109,7 @@ class Message: (rb"\{right\}", b"\x10\x28"), (rb"\{pause\}", b"\x10\x05"), (rb"\{erase\}", b"\x10\x03"), - (rb"\{serial\}", b"\x10\x09"), + (rb"\{serialnum\}", b"\x10\x09"), (rb"\{bell\}", b"\x10\x05"), (rb"\{red\}", b"\x10\x0c" + pack("B", Colors.RED.value)), (rb"\{green\}", b"\x10\x0c" + pack("B", Colors.GREEN.value)), @@ -183,16 +191,26 @@ class NetBrite: try: self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(5000) + self.sock.settimeout(2) self.connect() except OSError as e: - raise ConnectionError(f"Error while opening network socket. {e}") + raise NetbriteConnectionException( + f"Error while opening network socket. {e}" + ) def connect(self): - self.sock.connect((self.address, self.port)) + try: + self.sock.connect((self.address, self.port)) + except OSError as e: + raise NetbriteConnectionException( + f"Error while opening network socket. {e}" + ) def tx(self, pkt: bytes): - _ = self.sock.send(pkt) + try: + _ = self.sock.send(pkt) + except OSError as e: + raise NetbriteTransferException(f"Error while opening network socket. {e}") def message(self, msg: Message, zoneName: str): z = self.zones_list.get(zoneName)