package quant.rich.emoney.controller.api; import java.io.Serializable; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.lang.NonNull; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.protobuf.nano.MessageNano; import jakarta.annotation.PostConstruct; import org.apache.commons.lang3.StringUtils; import org.reflections.Reflections; import lombok.Data; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import nano.BaseResponse.Base_Response; import quant.rich.emoney.annotation.ResponseDecodeExtension; import quant.rich.emoney.entity.sqlite.ProtocolMatch; import quant.rich.emoney.exception.RException; import quant.rich.emoney.pojo.dto.EmoneyConvertResult; import quant.rich.emoney.pojo.dto.EmoneyProtobufBody; import quant.rich.emoney.service.sqlite.ProtocolMatchService; import quant.rich.emoney.util.SpringBeanDetector; import quant.rich.emoney.util.SpringContextHolder; /** * 益盟 ProtocolBuf 报文解析 API 控制器 */ @RestController @RequestMapping("/api/v1/proto") @Slf4j public class ProtoDecodeControllerV1 { @Autowired ProtocolMatchService protocolMatchService; @Autowired Reflections reflections; Map> responseDecodeExtensions = new HashMap>(); @Data @RequiredArgsConstructor private static class MethodInfo { final Method method; final Class declaringClass; final Integer order; Object instance; } @PostConstruct void postConstruct() { // Reflections 扫描所有注解并根据 protocolId 和 order 排序 Set methods = reflections.getMethodsAnnotatedWith(ResponseDecodeExtension.class); for (Method m : methods) { MethodInfo info; ResponseDecodeExtension ex = m.getAnnotation(ResponseDecodeExtension.class); String protocolId = ex.protocolId(); Integer order = ex.order(); // 判断 method 是否为单参数接受 JsonNode 的方法 Class[] parameterTypes = m.getParameterTypes(); Class declaringClass = m.getDeclaringClass(); if (parameterTypes.length != 1 || parameterTypes[0] != JsonNode.class) { log.warn("方法 {}#{} 不为类型为 JsonNode 的单参数方法,暂不支持作为解码额外选项", declaringClass.getSimpleName(), m.getName(), declaringClass.getSimpleName()); continue; } // 判断 method 是否是静态 if (Modifier.isStatic(m.getModifiers())) { info = new MethodInfo(m, null, order); } else { if (!SpringBeanDetector.isSpringManagedClass(declaringClass)) { log.warn("方法 {} 所属类 {} 不归属于 Spring 管理,目前暂不支持作为解码额外选项", m.getName(), declaringClass.getSimpleName()); continue; } info = new MethodInfo(m, declaringClass, order); } List list = responseDecodeExtensions.get(protocolId); if (list == null) { list = new ArrayList<>(); list.add(info); responseDecodeExtensions.put(protocolId, list); } else { list.add(info); } } for (List list : responseDecodeExtensions.values()) { list.sort(Comparator.comparingInt(info -> info.getOrder())); } log.debug("ResponseDecodeExtension: 共载入 {} 个 ProtocolID 的 {} 个方法", responseDecodeExtensions.keySet().size(), responseDecodeExtensions.values().size()); } /** * 解析 emoney protobuf 的请求 * @param * @param body * @return */ @SuppressWarnings("unchecked") @PostMapping("/request/decode") public EmoneyConvertResult requestDecode( @RequestBody(required=true) @NonNull EmoneyProtobufBody body) { Integer protocolId = body.getProtocolId(); if (Objects.isNull(protocolId)) { throw RException.badRequest("protocolId 不能为 null"); } ProtocolMatch match = protocolMatchService.getById(protocolId); if (Objects.isNull(match) || StringUtils.isBlank(match.getClassName())) { throw RException .badRequest("暂无对应 protocolId = ", protocolId, " 的记录,可等待 response decoder 收集到后再重试"); } String className = new StringBuilder() .append("nano.") .append(match.getClassName()) .append("Request$") .append(match.getClassName()) .append("_Request") .toString(); // IndexInflow -> nano.IndexInflowRequest$IndexInflow_Request Class clazz; try { clazz = (Class)Class.forName(className); } catch (Exception e) { String msg = new StringBuilder() .append("无法根据给定的 protocolId = ") .append(protocolId) .append(", className = ") .append(className) .append("找到对应类").toString(); log.warn(msg, e); throw RException.internalServerError(msg); } byte[] buf; try { buf = body.protocolBodyToByte(); } catch (Exception e) { throw RException.badRequest("转换 protocolBody 错误"); } try { U nano = (U)MessageNano.mergeFrom((MessageNano) clazz.getDeclaredConstructor().newInstance(), buf); return EmoneyConvertResult .ok(new ObjectMapper().valueToTree(nano)) .setProtocolId(protocolId) .setSupposedClassName(className); } catch (Exception e) { String msg = new StringBuilder() .append("转换为类 ") .append(className) .append(" 时错误").toString(); log.warn(msg, e); return EmoneyConvertResult.error(msg) .setProtocolId(protocolId) .setSupposedClassName(className); } } @SuppressWarnings("unchecked") @PostMapping("/response/decode") public EmoneyConvertResult responseDecode( @RequestBody(required=false) @NonNull EmoneyProtobufBody body) { Integer protocolId = body.getProtocolId(); ProtocolMatch match = null; if (Objects.isNull(protocolId)) { log.warn("protocolId 为空 null, 无法更新 protocolMatch"); } else { match = protocolMatchService.getById(protocolId); if (Objects.isNull(match)) { match = new ProtocolMatch().setProtocolId(protocolId); } } byte[] buf; try { buf = body.protocolBodyToByte(); } catch (Exception e) { throw RException.badRequest("转换 protocolBody 错误"); } Base_Response baseResponse; try { baseResponse = Base_Response.parseFrom(buf); } catch (Exception e) { String msg = new StringBuilder() .append("转换 BaseResponse 发生错误") .toString(); log.warn(msg, e); return EmoneyConvertResult .error(msg) .setProtocolId(protocolId); } Class clazz; String className = baseResponse.detail.getTypeUrl() .replace("type.googleapis.com/", ""); String rawClassName = className.substring(0, className.lastIndexOf('_')); if (Objects.nonNull(match) && StringUtils.isBlank(match.getClassName())) { match.setClassName(rawClassName); protocolMatchService.saveOrUpdate(match); } className = new StringBuilder() .append("nano.") .append(rawClassName) .append("Response$") .append(className) .toString(); try { clazz = (Class)Class.forName(className); } catch (Exception e) { String msg = new StringBuilder() .append("无法根据给定的 protocolId = ") .append(protocolId) .append(", className = ") .append(className) .append("找到对应类").toString(); log.warn(msg, e); throw RException.internalServerError(msg); } try { U nano = (U)MessageNano.mergeFrom( (MessageNano)clazz.getDeclaredConstructor().newInstance(), baseResponse.detail.getValue()); JsonNode jo = new ObjectMapper().valueToTree(nano); // 查找 ResponseDecodeExtension List methodInfos = responseDecodeExtensions.get(protocolId.toString()); if (methodInfos != null) { for (MethodInfo methodInfo : methodInfos) { if (methodInfo.getInstance() != null) { // instance 不为 null 则说明是已经取到的 spring bean, 直接调用 methodInfo.getMethod().invoke(methodInfo.getInstance(), jo); } else if (methodInfo.getDeclaringClass() != null) { // 获取 spring 管理的实例类 Object instance = SpringContextHolder.getBean(methodInfo.getDeclaringClass()); methodInfo.getMethod().invoke(instance, jo); methodInfo.setInstance(instance); } else { // 静态方法直接 invoke methodInfo.getMethod().invoke(null, jo); } } } return EmoneyConvertResult .ok((Serializable)jo) .setProtocolId(protocolId) .setSupposedClassName(className); } catch (Exception e) { String msg = new StringBuilder() .append("转换为类 ") .append(className) .append(" 时错误").toString(); log.warn(msg, e); return EmoneyConvertResult.error(msg) .setProtocolId(protocolId) .setSupposedClassName(className); } } }