#!/usr/bin/env python3
"""Migrate Baidu-only controllers to AbstractRecognizeController pattern."""
import os
import re
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
CTRL_ROOT = ROOT / "api-web/api-interface/src/main/java/com/heyu/api/controller"
HANDLE_ROOT = ROOT / "api-third/src/main/java/com/heyu/api/baidu/handle"

SKIP = {
    "BaseController.java",
    "AbstractRecognizeController.java",
    "TestController.java",
}

BIZ_NAMES = {
    "bus_ticket": "汽车票识别",
    "invoice": "通用机打发票识别",
    "seal": "印章识别",
    "idcard": "身份证识别",
    "medical_outpatient": "门诊病历识别",
}

# Java source root that request/response classes are resolved against.
_THIRD_JAVA_ROOT = ROOT / "api-third/src/main/java"

# Response type names that are emitted as a raw java.util.Map.
_MAP_RESPONSE_TYPES = ("Map", "JSONObject", "java.util.Map")

# Controller base name -> human readable business label.
_BIZ_LABELS = {
    "BusTicket": "汽车票识别",
    "Invoice": "通用机打发票识别",
    "Seal": "印章识别",
    "TaxiReceipt": "出租车票识别",
    "AirTicket": "飞机行程单识别",
    "BankReceiptNew": "银行回单识别",
    "BankReceiptNewPro": "银行回单识别(专业版)",
    "FerryTicket": "船票识别",
    "ShoppingReceipt": "购物小票识别",
    "TollInvoice": "过路过桥费发票识别",
    "QuotaInvoice": "定额发票识别",
    "MultipleInvoice": "混贴发票识别",
    "OnlineTaxiItinerary": "网约车行程单识别",
    "VatInvoiceVerification": "增值税发票核验",
    "RecognizeVATInvoice": "增值税发票识别",
    "RecognizeVINCode": "VIN码识别",
    "RecognizeQrCode": "二维码识别",
    "LicensePlate": "车牌识别",
    "IdCard": "身份证识别",
    "IdCardMulti": "身份证混贴识别",
    "ForgeryDetection": "图片篡改检测",
}

_TYPED_HANDLE_RE = re.compile(
    r"class\s+(\w+)\s+extends\s+BaiduBaseHandle<([^,>]+),\s*([^>]+)>"
)
_RAW_HANDLE_RE = re.compile(
    r"class\s+(\w+)\s+extends\s+"
    r"(?:BEducationJsonHandle|BaiduFormRawHandle|BaiduMultipartRawHandle)<([^>]+)>"
)
_GET_URI_RE = re.compile(r'getUri\(\)\s*\{[^}]*return\s+"([^"]+)"', re.S)

# One mapped controller method: optional leading annotations, the mapping
# annotation itself, optional trailing annotations, then `public R name(params) {`.
_METHOD_RE = re.compile(
    r"(@\w+(?:\([^)]*\))?\s+)*@(?P<mapAnn>RequestMapping|PostMapping)"
    r"\((?P<mapping_raw>\{[^}]+\}|\"[^\"]+\")\)\s+"
    r"((?:@\w+(?:\([^)]*\))?\s+)*)"
    r"public\s+R(?:<[^>]+>)?\s+(?P<method_name>\w+)\((?P<params>[^)]*)\)\s*\{",
    re.S,
)


def parse_handle_source(text):
    """Extract handle metadata from the Java source text of one handle class.

    Returns ``(class_name, entry)`` or ``None`` when the text does not declare
    a recognised handle. ``entry`` has the keys ``request``, ``response``,
    ``uri`` and ``json_body``.

    * ``BaiduBaseHandle<Req, Resp>`` subclasses are only recorded when a
      ``getUri()`` string literal is found; ``json_body`` is ``False``.
    * ``BEducationJsonHandle`` / ``BaiduFormRawHandle`` /
      ``BaiduMultipartRawHandle`` subclasses are always recorded with a
      ``java.util.Map`` response and ``json_body`` ``True``; ``uri`` is
      ``None`` when no ``getUri()`` literal is found.
    """
    uri_m = _GET_URI_RE.search(text)
    uri = uri_m.group(1) if uri_m else None

    typed = _TYPED_HANDLE_RE.search(text)
    if typed:
        if uri is None:
            return None
        return typed.group(1), {
            "request": typed.group(2).strip(),
            "response": typed.group(3).strip(),
            "uri": uri,
            "json_body": False,
        }

    raw = _RAW_HANDLE_RE.search(text)
    if raw:
        return raw.group(1), {
            "request": raw.group(2).strip(),
            "response": "java.util.Map",
            "uri": uri,
            "json_body": True,
        }
    return None


def load_handles():
    """Scan HANDLE_ROOT for ``*Handle.java`` files and index their metadata.

    Returns a dict mapping handle class name to the entry produced by
    :func:`parse_handle_source`.
    """
    handles = {}
    for path in HANDLE_ROOT.rglob("*Handle.java"):
        text = path.read_text(encoding="utf-8", errors="ignore")
        parsed = parse_handle_source(text)
        if parsed:
            name, entry = parsed
            handles[name] = entry
    return handles


def is_simple_method(body: str) -> bool:
    """Return True when a controller method body looks mechanically migratable.

    Bodies containing a Java ``for`` loop are never simple. Otherwise the body
    must match one of three textual shapes: a ``BaiduOcrResult.raw`` call, a
    ``.handle(`` call with ``return R.error`` and no ``new ArrayList``, or a
    ``.handle(`` call guarded by ``if (bR.isSuccess())`` with at most three
    ``return`` occurrences.
    """
    if "for (" in body or "for(" in body:
        return False
    if "BaiduOcrResult.raw" in body:
        return True
    if ".handle(" not in body:
        return False
    if "return R.error" in body and "new ArrayList" not in body:
        return True
    return "if (bR.isSuccess())" in body and body.count("return") <= 3


def infer_biz_label(class_name, uri):
    """Derive the business label for a controller class.

    The label table is consulted with the un-stripped base name first so that
    keys which themselves start with ``Recognize`` (for example
    ``RecognizeVATInvoice``) can match, then with the ``Recognize`` prefix
    removed. Unknown names fall back to ``<base>识别``. ``uri`` is accepted for
    interface compatibility and is not used.
    """
    full = class_name.replace("Controller", "")
    base = full[len("Recognize"):] if full.startswith("Recognize") else full
    for candidate in (full, base):
        if candidate in _BIZ_LABELS:
            return _BIZ_LABELS[candidate]
    # The earlier camel-case substitution replaced every match with itself,
    # so the fallback has always been the unchanged base name.
    return base + "识别"


def infer_biz_key(uri, class_name):
    """Return the last non-empty path segment of ``uri``, else the lowercased
    class name with ``Controller`` removed."""
    if uri:
        tail = uri.rstrip("/").split("/")[-1]
        if tail:
            return tail
    return class_name.replace("Controller", "").lower()


def _method_body(text, start):
    """Return the text between the brace opened just before ``start`` and its
    matching closing brace. Braces inside Java string literals are counted too."""
    depth = 1
    i = start
    while i < len(text) and depth:
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
        i += 1
    return text[start : i - 1]


def parse_controller(path: Path):
    """Parse one controller source file.

    Returns:
        ``None`` when the file is not a migration candidate (already migrated,
        not Baidu based, mixes in Alibaba, or declares no ``public class``);
        ``"MANUAL"`` for controllers that are explicitly excluded;
        ``"COMPLEX"`` when any mapped method fails :func:`is_simple_method`;
        otherwise a dict describing the controller.
    """
    text = path.read_text(encoding="utf-8")
    if "extends AbstractRecognizeController" in text:
        return None
    if "com.heyu.api.baidu" not in text:
        return None
    if "com.heyu.api.alibaba" in text:
        return None
    # skip manual complex for now
    if path.name in ("MedicalOutpatientController.java",):
        return "MANUAL"

    class_m = re.search(r"public class (\w+)", text)
    if not class_m:
        # Nothing to migrate; previously this raised AttributeError.
        return None
    class_name = class_m.group(1)

    pkg_m = re.search(r"package\s+([\w.]+);", text)
    pkg = pkg_m.group(1) if pkg_m else ""

    base_mapping = ""
    mapping_m = re.search(r"@RequestMapping\((\{[^}]+\}|\"[^\"]+\")", text)
    if mapping_m:
        base_mapping = mapping_m.group(1).strip().strip('"')
        if base_mapping.startswith("{"):
            # Use the first path as the class-level @RequestMapping.
            inner = re.search(r'"(/[^"]+)"', base_mapping)
            if inner:
                base_mapping = inner.group(1)

    annotations = ["@NotIntercept"] if "@NotIntercept" in text else []
    has_slf4j = "@Slf4j" in text

    methods = []
    for m in _METHOD_RE.finditer(text):
        body = _method_body(text, m.end())
        if not is_simple_method(body):
            return "COMPLEX"
        methods.append(
            {
                "mapping_raw": m.group("mapping_raw"),
                "map_ann": m.group("mapAnn"),
                "method_name": m.group("method_name"),
                "params": m.group("params").strip(),
                "body": body,
            }
        )

    handle_injects = re.findall(r"private\s+(\w+Handle)\s+(\w+);", text)
    return {
        "path": path,
        "pkg": pkg,
        "class_name": class_name,
        "base_mapping": base_mapping,
        "annotations": annotations,
        "has_slf4j": has_slf4j,
        "methods": methods,
        "handles": handle_injects,
        "text": text,
    }


def extract_handle_from_body(body, handles_meta):
    """Return the handle class name a method body calls, or ``None``.

    A capitalised ``XxxHandle.handle(`` / ``.check(`` / ``.getContent(``
    receiver is returned as is. Otherwise the receiver is treated as a field
    name and matched against the bean-style (first letter lowercased) names of
    the classes in ``handles_meta``.
    """
    # The \b anchor stops a camelCase field such as `busTicketHandle` from
    # being read as the class name `TicketHandle`.
    for pattern in (
        r"\b([A-Z]\w*Handle)\.handle\(",
        r"\b([A-Z]\w*Handle)\.(?:check|getContent)\(",
    ):
        m = re.search(pattern, body)
        if m:
            return m.group(1)

    for pattern in (r"(\w+)\.handle\(", r"(\w+)\.(?:check|getContent)\("):
        m = re.search(pattern, body)
        if not m:
            continue
        field = m.group(1)
        for handle_name in handles_meta:
            if handle_name[0].lower() + handle_name[1:] == field:
                return handle_name
    return None


def _request_variable(first_setup_line):
    """Return the Java variable named by the first request-setup statement.

    For ``Type var = new Type();`` this is ``var`` (the token just before
    ``=``); for a setter call such as ``var.setX(...)`` it is the receiver.
    Returns ``None`` when no name can be read.
    """
    if "= new" in first_setup_line:
        left = first_setup_line.split("=", 1)[0].split()
        return left[-1] if left else None
    receiver = first_setup_line.split(".", 1)[0].strip()
    return receiver or None


def generate_method(method, handle_name, field_name, handles_meta, class_name, biz_label, biz_key):
    """Render the Java source of one migrated controller method.

    Args:
        method: method dict produced by :func:`parse_controller`.
        handle_name: handle class name to look up in ``handles_meta``.
        field_name: name of the injected handle field in the controller.
        handles_meta: mapping built by :func:`load_handles`.
        class_name: controller class name (unused, kept for compatibility).
        biz_label: label written to ``bizLabel`` / ``targetLabel``.
        biz_key: key written to ``bizKey`` / ``sideKey``.

    Returns:
        The Java method text, or ``None`` when the handle is unknown or its
        response type is neither a map type, fully qualified under ``com.``,
        nor a capitalised class name.
    """
    handle = handles_meta.get(handle_name)
    if not handle:
        return None

    resp_type = handle["response"]
    use_map = resp_type in _MAP_RESPONSE_TYPES
    resp_type_java = "java.util.Map" if use_map else resp_type
    if (
        not use_map
        and not resp_type_java.startswith("com.")
        and not resp_type_java[:1].isupper()
    ):
        return None

    params = method["params"]
    param_name = params.split()[-1].replace(",", "").strip() if params else "request"
    if not param_name or param_name == "void":
        param_name = "request"

    # NOTE(review): a handle without getUri() has uri=None, which is rendered
    # below as the literal text None - confirm whether such handles should be
    # skipped instead.
    uri = handle["uri"]
    side_key = biz_key

    # detect request conversion in body
    pre_code = ""
    actual_param = param_name
    body = method["body"]

    if "new B" in body and "setImageUrl" in body:
        # IdCard style mapping - keep original body setup
        setup = []
        for line in body.splitlines():
            line = line.strip()
            if line.startswith("B") and "Request" in line and "= new" in line:
                setup.append(line)
            elif line.startswith("b") and ".set" in line:
                setup.append(line)
            if "handle(" in line:
                break
        if setup:
            pre_code = "\n        ".join(setup)
            # Pass the converted request variable, not its Java type.
            actual_param = _request_variable(setup[0]) or param_name

    if "new B" in body and "setVerifyNum" in body:
        setup = []
        for line in body.splitlines():
            line = line.strip()
            if ("Request" in line and "= new" in line) or line.startswith("bRequest."):
                if "handle(" in line:
                    break
                setup.append(line)
        if setup:
            pre_code = "\n        ".join(setup)
            actual_param = _request_variable(setup[0]) or param_name

    # Generated methods always carry @CacheResult.
    cache = "@CacheResult\n    "

    post_mapping = method["mapping_raw"].strip()
    if not post_mapping.startswith("{"):
        # Built outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        post_mapping = '"' + post_mapping.strip('"') + '"'

    pre_block = f"{pre_code}\n        " if pre_code else ""
    map_ann = method.get("map_ann", "RequestMapping")

    return f"""
    @{map_ann}({post_mapping})
    {cache}public R {method['method_name']}({params}) {{
        {pre_block}return executeLegacyBaiduRecognize(LegacyBaiduRecognizeSpec.<{resp_type_java}>builder()
                .bizLabel("{biz_label}")
                .bizKey("{side_key}")
                .uri("{uri}")
                .sideKey("{side_key}")
                .targetLabel("{biz_label.replace('识别', '')}")
                .inputLog(buildInputLog({actual_param}))
                .respClass({resp_type_java}.class)
                .paramCheck(() -> {field_name}.check({actual_param}))
                .requestContent(() -> {field_name}.getContent({actual_param}))
                .build());
    }}"""


def _pick_handle(method, handle_names, handles_meta, position):
    """Choose the ``(handle_type, field_name)`` pair a method should use.

    A single injected handle is always used. With several, the method body is
    inspected first (by field name, then via extract_handle_from_body); the
    positional pairing is only a fallback. Returns ``None`` when nothing fits.
    """
    if len(handle_names) == 1:
        return handle_names[0]

    body = method["body"]
    for handle_type, field_name in handle_names:
        call = rf"\b{re.escape(field_name)}\.(?:handle|check|getContent)\("
        if re.search(call, body):
            return handle_type, field_name

    resolved = extract_handle_from_body(body, handles_meta)
    for handle_type, field_name in handle_names:
        if handle_type == resolved:
            return handle_type, field_name

    if position < len(handle_names):
        return handle_names[position]
    return None


def generate_controller(info, handles_meta):
    """Render the full Java source of a migrated controller.

    Args:
        info: controller dict produced by :func:`parse_controller`.
        handles_meta: mapping built by :func:`load_handles`.

    Returns:
        The Java file content, or ``None`` when the controller cannot be
        migrated: no injected handle, unknown first handle, no parsed method,
        or a method that :func:`generate_method` rejects.
    """
    class_name = info["class_name"]
    handle_names = list(info["handles"])
    if not handle_names:
        return None
    # Writing a controller without methods would wipe the original endpoints.
    if not info["methods"]:
        return None
    first_meta = handles_meta.get(handle_names[0][0])
    if not first_meta:
        return None

    uri = first_meta["uri"]
    biz_label = infer_biz_label(class_name, uri)
    biz_key = infer_biz_key(uri, class_name)

    resp_meta = first_meta["response"]
    use_map = resp_meta in _MAP_RESPONSE_TYPES
    resp_type = "java.util.Map" if use_map else resp_meta

    methods_code = []
    for position, method in enumerate(info["methods"]):
        picked = _pick_handle(method, handle_names, handles_meta, position)
        if not picked:
            return None
        handle, field_name = picked
        mc = generate_method(
            method, handle, field_name, handles_meta, class_name, biz_label, biz_key
        )
        if not mc:
            return None
        methods_code.append(mc)

    imports = {
        "com.heyu.api.controller.AbstractRecognizeController",
        "com.heyu.api.data.utils.R",
        "org.springframework.web.bind.annotation.RequestMapping",
        "org.springframework.web.bind.annotation.RestController",
        "org.springframework.beans.factory.annotation.Autowired",
        "com.heyu.api.data.annotation.CacheResult",
    }
    if info["has_slf4j"]:
        imports.add("lombok.extern.slf4j.Slf4j")

    any_map = use_map
    for handle_type, _field in handle_names:
        imports.add(f"com.heyu.api.baidu.handle.{package_for_handle(handle_type)}")
        meta = handles_meta.get(handle_type, {})
        if meta.get("request"):
            imports.add(request_import(meta["request"]))
        response = meta.get("response")
        if not response:
            continue
        # Decide per handle: a map response needs no project import, while a
        # typed response needs one even if the first handle returns a map.
        if response in _MAP_RESPONSE_TYPES:
            any_map = True
        else:
            imports.add(response_import(response))

    # extra imports from original
    text = info["text"]
    if "ApiIdentityCardRequest" in text:
        imports.add("com.heyu.api.request.certificate.ApiIdentityCardRequest")
    if "BusinesslicenseVerificationDetailedReq" in text:
        imports.add(
            "com.heyu.api.request.certificate.BusinesslicenseVerificationDetailedReq"
        )
    if any_map:
        imports.add("java.util.Map")
    if "@NotIntercept" in text:
        imports.add("com.heyu.api.data.annotation.NotIntercept")
    if "@EbAuthentication" in text:
        imports.add("com.heyu.api.data.annotation.EbAuthentication")
        imports.add("com.heyu.api.data.constants.ApiConstants")
    if any(m.get("map_ann") == "PostMapping" for m in info["methods"]):
        imports.add("org.springframework.web.bind.annotation.PostMapping")

    lines = []
    if info["pkg"]:
        lines.extend([f"package {info['pkg']};", ""])
    lines.extend(f"import {imp};" for imp in sorted(imports))
    lines.append("")
    if info["has_slf4j"]:
        lines.append("@Slf4j")
    lines.append("@RestController")
    lines.append(f'@RequestMapping("{info["base_mapping"]}")')
    if "@NotIntercept" in text:
        lines.append("@NotIntercept")
    lines.extend(
        [
            f"public class {class_name} extends AbstractRecognizeController {{",
            "",
            f'    private static final String SIDE_KEY = "{biz_key}";',
            "",
        ]
    )
    for handle_type, field_name in handle_names:
        lines.append("    @Autowired")
        lines.append(f"    private {handle_type} {field_name};")
        lines.append("")
    lines.extend(methods_code)

    empty_resp = "new java.util.HashMap<>()" if use_map else "new " + resp_type + "()"
    lines.extend(
        [
            "",
            "    @Override",
            "    protected Object defaultEmptyResp(String side) {",
            f"        return {empty_resp};",
            "    }",
            "",
            "    private String buildInputLog(Object request) {",
            "        if (request == null) {",
            '            return "未收到请求体";',
            "        }",
            '        return "已收到请求体,具体字段见业务入参";',
            "    }",
            "}",
            "",
        ]
    )
    return "\n".join(lines)


def package_for_handle(handle_type: str) -> str:
    """Return the handle's dotted path relative to HANDLE_ROOT.

    For ``HANDLE_ROOT/ticket/BusTicketHandle.java`` this is
    ``ticket.BusTicketHandle``. A handle located directly in HANDLE_ROOT, or
    not found at all, yields just ``handle_type``.
    """
    for path in HANDLE_ROOT.rglob(f"{handle_type}.java"):
        rel = path.relative_to(HANDLE_ROOT).parent
        # Joining the parts avoids a leading dot when rel has no parts.
        return ".".join((*rel.parts, handle_type))
    return handle_type


def _resolve_import(simple_name, fallback_package):
    """Return the fully qualified name of ``simple_name`` found under the
    api-third main source root, else ``fallback_package.simple_name``.

    Only ``src/main/java`` is searched so every hit can be made relative to
    it; candidates are sorted so the choice is deterministic.
    """
    matches = sorted(_THIRD_JAVA_ROOT.rglob(f"{simple_name}.java"))
    if matches:
        rel = matches[0].relative_to(_THIRD_JAVA_ROOT)
        return ".".join(rel.with_suffix("").parts)
    return f"{fallback_package}.{simple_name}"


def request_import(req: str) -> str:
    """Return the import path for a request class, defaulting to
    ``com.heyu.api.baidu.request``."""
    return _resolve_import(req, "com.heyu.api.baidu.request")


def response_import(resp: str) -> str:
    """Return the import path for a response class, defaulting to
    ``com.heyu.api.baidu.response``."""
    return _resolve_import(resp, "com.heyu.api.baidu.response")


def main():
    """Migrate every eligible controller under CTRL_ROOT in place and print a
    summary of migrated, skipped and manual files."""
    handles_meta = load_handles()
    migrated = []
    skipped = []
    manual = []
    for path in sorted(CTRL_ROOT.rglob("*Controller.java")):
        if path.name in SKIP:
            continue
        info = parse_controller(path)
        if info in ("MANUAL", "COMPLEX"):
            manual.append(str(path))
            continue
        if not info:
            continue
        code = generate_controller(info, handles_meta)
        if not code:
            skipped.append(str(path))
            continue
        path.write_text(code, encoding="utf-8")
        migrated.append(str(path))

    prefix = str(ROOT) + "/"
    report = (
        ("migrated", "OK", migrated),
        ("skipped", "SKIP", skipped),
        ("manual", "MANUAL", manual),
    )
    for title, tag, paths in report:
        print(f"{title}: {len(paths)}")
        for p in paths:
            print(f" {tag}", p.replace(prefix, ""))


if __name__ == "__main__":
    main()