LocationWebSocketHandshakeInterceptor.java 3.1 KB
package com.diligrp.rider.websocket;

import com.diligrp.rider.common.exception.BizException;
import com.diligrp.rider.config.JwtUtil;
import com.diligrp.rider.entity.Substation;
import com.diligrp.rider.mapper.SubstationMapper;
import io.jsonwebtoken.Claims;
import lombok.RequiredArgsConstructor;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import java.util.Map;

/**
 * Authenticates WebSocket handshakes for the location-push endpoint.
 *
 * <p>The JWT is read from the {@code token} query parameter, falling back to the
 * {@code Authorization} header; an optional {@code "Bearer "} prefix is stripped. The token is
 * parsed with {@link JwtUtil#getAdminClaims(String)}, and its {@code adminId} / {@code role}
 * claims decide which city the connection may subscribe to:
 * <ul>
 *   <li>{@code substation}: the city of the {@link Substation} looked up by {@code adminId};</li>
 *   <li>{@code admin}: the {@code cityId} query parameter, which is then mandatory;</li>
 *   <li>any other role is rejected.</li>
 * </ul>
 * On success {@code adminId}, {@code role} and {@code cityId} are stored in the handshake
 * attributes. Failures are signalled by throwing {@link BizException}.
 */
@Component
@RequiredArgsConstructor
public class LocationWebSocketHandshakeInterceptor implements HandshakeInterceptor {

    /** Prefix stripped from an {@code Authorization}-style token value. */
    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtUtil jwtUtil;
    private final SubstationMapper substationMapper;

    /**
     * Validates the caller's token and resolves the city whose location updates it may receive.
     *
     * @return {@code false} if the request is not servlet-based (handshake rejected);
     *     {@code true} once {@code attributes} has been populated
     * @throws BizException 700 when the token is missing, its claims are unusable, or the
     *     substation cannot be found; 400 when an admin omits or malforms {@code cityId};
     *     403 for any other role
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                   WebSocketHandler wsHandler, Map<String, Object> attributes) {
        if (!(request instanceof ServletServerHttpRequest servletRequest)) {
            return false;
        }
        String token = resolveToken(servletRequest);

        Claims claims = jwtUtil.getAdminClaims(token);
        // Defensive guard so a null result fails as "session expired" rather than with an NPE.
        // NOTE(review): whether getAdminClaims can return null is not visible here — confirm.
        if (claims == null) {
            throw new BizException(700, "登录状态失效,请重新登录");
        }
        Object adminIdObj = claims.get("adminId");
        String role = claims.get("role", String.class);
        if (!(adminIdObj instanceof Number adminIdNumber) || !StringUtils.hasText(role)) {
            throw new BizException(700, "登录状态失效,请重新登录");
        }
        Long adminId = adminIdNumber.longValue();
        Long cityId = resolveCityId(servletRequest, adminId, role);

        attributes.put("adminId", adminId);
        attributes.put("role", role);
        attributes.put("cityId", cityId);
        return true;
    }

    /**
     * Extracts the raw JWT from the {@code token} query parameter or, failing that, the
     * {@code Authorization} header, stripping an optional {@code "Bearer "} prefix.
     *
     * @throws BizException 700 when no non-blank token remains
     */
    private static String resolveToken(ServletServerHttpRequest servletRequest) {
        String token = servletRequest.getServletRequest().getParameter("token");
        if (!StringUtils.hasText(token)) {
            token = servletRequest.getHeaders().getFirst("Authorization");
        }
        if (token != null && token.startsWith(BEARER_PREFIX)) {
            token = token.substring(BEARER_PREFIX.length());
        }
        // Checked after stripping so that a bare "Bearer " value is treated as "not logged in"
        // instead of handing an empty string to the JWT parser.
        if (!StringUtils.hasText(token)) {
            throw new BizException(700, "请先登录");
        }
        return token;
    }

    /**
     * Determines the city a connection may subscribe to, based on the caller's role.
     *
     * @throws BizException 700 if a substation account has no matching substation or city;
     *     400 if an admin omits or malforms {@code cityId}; 403 for any other role
     */
    private Long resolveCityId(ServletServerHttpRequest servletRequest, Long adminId, String role) {
        if ("substation".equals(role)) {
            // The substation row is looked up by the adminId claim.
            // NOTE(review): assumes a substation's primary key equals its login id — confirm.
            Substation substation = substationMapper.selectById(adminId);
            if (substation == null || substation.getCityId() == null) {
                throw new BizException(700, "分站信息不存在");
            }
            return substation.getCityId();
        }
        if ("admin".equals(role)) {
            return parseCityId(servletRequest.getServletRequest().getParameter("cityId"));
        }
        throw new BizException(403, "当前角色无权订阅位置推送");
    }

    /**
     * Parses the admin-supplied {@code cityId} query parameter.
     *
     * @throws BizException 400 if the value is blank or not a valid {@code long}
     */
    private static Long parseCityId(String cityIdText) {
        if (!StringUtils.hasText(cityIdText)) {
            throw new BizException(400, "管理员连接WebSocket时必须传cityId");
        }
        try {
            return Long.valueOf(cityIdText.trim());
        } catch (NumberFormatException e) {
            // A malformed id is a client error, like a missing one; do not let the raw
            // NumberFormatException surface as a generic server failure.
            throw new BizException(400, "cityId格式不正确");
        }
    }

    /** No post-handshake work is required. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                               WebSocketHandler wsHandler, Exception exception) {
    }
}