eb-service-api/scripts/migrate_baidu_controllers.py
2026-06-24 01:38:07 +08:00

509 lines
17 KiB
Python

#!/usr/bin/env python3
"""Migrate Baidu-only controllers to AbstractRecognizeController pattern."""
import os
import re
from pathlib import Path
# Project root: the parent of the directory that holds this script.
ROOT = Path(__file__).resolve().parents[1]
# Java controller sources; main() scans this tree for *Controller.java files.
CTRL_ROOT = ROOT / "api-web/api-interface/src/main/java/com/heyu/api/controller"
# Baidu handle sources; load_handles() scans this tree for *Handle.java files.
HANDLE_ROOT = ROOT / "api-third/src/main/java/com/heyu/api/baidu/handle"
# Controller file names that main() never parses or rewrites.
SKIP = {
"BaseController.java",
"AbstractRecognizeController.java",
"TestController.java",
}
# Business key -> Chinese display label.
# NOTE(review): no function shown in this file reads BIZ_NAMES (infer_biz_label
# keeps its own table) - confirm whether this mapping is still needed.
BIZ_NAMES = {
"bus_ticket": "汽车票识别",
"invoice": "通用机打发票识别",
"seal": "印章识别",
"idcard": "身份证识别",
"medical_outpatient": "门诊病历识别",
}
def load_handles(handle_root=None):
    """Collect metadata for every ``*Handle.java`` class under the handle tree.

    Two kinds of handle are recognised:

    * ``class X extends BaiduBaseHandle<Req, Resp>`` - registered only when a
      literal ``getUri()`` return value is found; ``json_body`` is ``False``.
    * ``class X extends BEducationJsonHandle<Req>`` (or ``BaiduFormRawHandle``
      / ``BaiduMultipartRawHandle``) - always registered with response type
      ``java.util.Map`` and ``json_body`` ``True``; ``uri`` is ``None`` when
      no literal ``getUri()`` return value is found.

    Args:
        handle_root: Directory to scan. Defaults to ``HANDLE_ROOT``.

    Returns:
        dict mapping handle class name to a dict with the keys ``request``,
        ``response``, ``uri`` and ``json_body``.
    """
    root = HANDLE_ROOT if handle_root is None else Path(handle_root)

    typed_re = re.compile(
        r"class\s+(\w+)\s+extends\s+BaiduBaseHandle<([^,>]+),\s*([^>]+)>"
    )
    raw_re = re.compile(
        r"class\s+(\w+)\s+extends\s+(?:BEducationJsonHandle|BaiduFormRawHandle|BaiduMultipartRawHandle)<([^>]+)>"
    )
    uri_re = re.compile(r'getUri\(\)\s*\{[^}]*return\s+"([^"]+)"', re.S)

    handles = {}
    for path in root.rglob("*Handle.java"):
        text = path.read_text(encoding="utf-8", errors="ignore")
        uri_m = uri_re.search(text)
        uri = uri_m.group(1) if uri_m else None

        typed = typed_re.search(text)
        if typed:
            # A typed handle without a literal URI is deliberately left out.
            if uri is not None:
                handles[typed.group(1)] = {
                    "request": typed.group(2).strip(),
                    "response": typed.group(3).strip(),
                    "uri": uri,
                    "json_body": False,
                }
            continue

        raw = raw_re.search(text)
        if raw:
            handles[raw.group(1)] = {
                "request": raw.group(2).strip(),
                "response": "java.util.Map",
                "uri": uri,
                "json_body": True,
            }
    return handles
def is_simple_method(body: str) -> bool:
    """Return True when a Java method body is simple enough to auto-migrate.

    The decision is a purely textual heuristic:

    * any ``for`` loop disqualifies the body;
    * a ``BaiduOcrResult.raw`` reference qualifies it outright;
    * otherwise the body must call ``.handle(`` and either return ``R.error``
      without building an ``ArrayList``, or test ``bR.isSuccess()`` with at
      most three occurrences of ``return``.
    """
    has_loop = "for (" in body or "for(" in body
    if has_loop:
        return False
    if "BaiduOcrResult.raw" in body:
        return True
    if ".handle(" not in body:
        return False

    plain_error_path = "return R.error" in body and "new ArrayList" not in body
    short_success_path = (
        "if (bR.isSuccess())" in body and body.count("return") <= 3
    )
    return plain_error_path or short_success_path
def infer_biz_label(class_name, uri):
    """Derive the Chinese business label for a controller class.

    ``Controller`` is removed from ``class_name`` and a leading ``Recognize``
    is dropped; the result is looked up in a fixed table. Some table keys
    themselves start with ``Recognize`` (e.g. ``RecognizeVATInvoice``), so the
    name *with* the prefix is tried as well - previously those entries could
    never match because the prefix was always stripped first.

    Args:
        class_name: Java controller class name, e.g. ``BusTicketController``.
        uri: Unused; kept so existing callers keep working.

    Returns:
        The mapped label, or the stripped class name followed by ``识别`` when
        no table entry applies.
    """
    full_name = class_name.replace("Controller", "")
    base = full_name
    if base.startswith("Recognize"):
        base = base[len("Recognize"):]
    mapping = {
        "BusTicket": "汽车票识别",
        "Invoice": "通用机打发票识别",
        "Seal": "印章识别",
        "TaxiReceipt": "出租车票识别",
        "AirTicket": "飞机行程单识别",
        "BankReceiptNew": "银行回单识别",
        "BankReceiptNewPro": "银行回单识别(专业版)",
        "FerryTicket": "船票识别",
        "ShoppingReceipt": "购物小票识别",
        "TollInvoice": "过路过桥费发票识别",
        "QuotaInvoice": "定额发票识别",
        "MultipleInvoice": "混贴发票识别",
        "OnlineTaxiItinerary": "网约车行程单识别",
        "VatInvoiceVerification": "增值税发票核验",
        "RecognizeVATInvoice": "增值税发票识别",
        "RecognizeVINCode": "VIN码识别",
        "RecognizeQrCode": "二维码识别",
        "LicensePlate": "车牌识别",
        "IdCard": "身份证识别",
        "IdCardMulti": "身份证混贴识别",
        "ForgeryDetection": "图片篡改检测",
    }
    # Stripped name first (the pre-existing behaviour), then the full name so
    # the "Recognize..." keys are reachable.
    for candidate in (base, full_name):
        if candidate in mapping:
            return mapping[candidate]
    # Fallback: the camel-case name is used unchanged. The earlier regex
    # substitution replaced each match with itself, so it never altered it.
    return base + "识别"
def infer_biz_key(uri, class_name):
    """Return a short business key for a controller.

    The last non-empty path segment of ``uri`` wins (trailing slashes are
    ignored). When ``uri`` is falsy or has no such segment, the key is the
    lower-cased class name with ``Controller`` removed.
    """
    last_segment = uri.rstrip("/").rsplit("/", 1)[-1] if uri else ""
    if last_segment:
        return last_segment
    return class_name.replace("Controller", "").lower()
def parse_controller(path: Path):
    """Parse one Java controller file and decide whether it can be migrated.

    Args:
        path: Location of the ``*Controller.java`` source file.

    Returns:
        * ``None`` - the file is out of scope: it already extends
          ``AbstractRecognizeController``, does not mention
          ``com.heyu.api.baidu``, mentions ``com.heyu.api.alibaba``, or
          declares no ``public class``.
        * ``"MANUAL"`` - the file is on the hard-coded manual list.
        * ``"COMPLEX"`` - at least one mapped method fails
          ``is_simple_method``.
        * otherwise a dict with the keys ``path``, ``pkg``, ``class_name``,
          ``base_mapping``, ``annotations``, ``has_slf4j``, ``methods``,
          ``handles`` and ``text``.
    """
    text = path.read_text(encoding="utf-8")
    if "extends AbstractRecognizeController" in text:
        return None
    if "com.heyu.api.baidu" not in text:
        return None
    if "com.heyu.api.alibaba" in text:
        return None
    # skip manual complex for now
    if path.name in ("MedicalOutpatientController.java",):
        return "MANUAL"

    class_m = re.search(r"public class (\w+)", text)
    if class_m is None:
        # No public class to rewrite (e.g. an interface); treat as out of scope
        # instead of failing on class_m.group(1).
        return None
    class_name = class_m.group(1)

    pkg_m = re.search(r"package\s+([\w.]+);", text)
    pkg = pkg_m.group(1) if pkg_m else ""

    # Only annotations placed before the class declaration are class-level.
    # Searching the whole file would pick up the first method-level mapping
    # whenever the class itself has none.
    class_header = text[: class_m.start()]
    mapping_m = re.search(r"@RequestMapping\((\{[^}]+\}|\"[^\"]+\")", class_header)
    base_mapping = ""
    if mapping_m:
        base_mapping = mapping_m.group(1).strip().strip('"')
        if base_mapping.startswith("{"):
            # Use the first path of the array as the class-level mapping.
            inner = re.search(r'"(/[^"]+)"', base_mapping)
            if inner:
                base_mapping = inner.group(1)

    annotations = []
    if "@NotIntercept" in text:
        annotations.append("@NotIntercept")
    has_slf4j = "@Slf4j" in text

    methods = []
    for m in re.finditer(
        r"(@\w+(?:\([^)]*\))?\s+)*@(?P<mapAnn>RequestMapping|PostMapping)\((?P<mapping_raw>\{[^}]+\}|\"[^\"]+\")\)\s+"
        r"((?:@\w+(?:\([^)]*\))?\s+)*)"
        r"public\s+R(?:<[^>]+>)?\s+(?P<method_name>\w+)\((?P<params>[^)]*)\)\s*\{",
        text,
        re.S,
    ):
        # Walk forward from the opening brace until it is balanced; the body
        # is everything strictly between the outermost braces.
        start = m.end()
        depth = 1
        pos = start
        while pos < len(text) and depth:
            if text[pos] == "{":
                depth += 1
            elif text[pos] == "}":
                depth -= 1
            pos += 1
        body = text[start : pos - 1]
        if not is_simple_method(body):
            return "COMPLEX"
        methods.append(
            {
                "mapping_raw": m.group("mapping_raw"),
                "map_ann": m.group("mapAnn"),
                "method_name": m.group("method_name"),
                "params": m.group("params").strip(),
                "body": body,
            }
        )

    # (type, field name) pairs of every injected *Handle field.
    handle_injects = re.findall(r"private\s+(\w+Handle)\s+(\w+);", text)
    return {
        "path": path,
        "pkg": pkg,
        "class_name": class_name,
        "base_mapping": base_mapping,
        "annotations": annotations,
        "has_slf4j": has_slf4j,
        "methods": methods,
        "handles": handle_injects,
        "text": text,
    }
def extract_handle_from_body(body, handles_meta):
    """Find which handle class a Java method body calls.

    Lookup order:

    1. a capitalised ``XxxHandle.handle(`` receiver;
    2. a capitalised ``XxxHandle.check(`` / ``XxxHandle.getContent(`` receiver;
    3. a field receiver (``xxxHandle.handle(``, then ``.check(`` /
       ``.getContent(``) whose name equals a key of ``handles_meta`` with its
       first letter lower-cased.

    The capitalised patterns are anchored with a word boundary. Without it a
    camel-case field such as ``busTicketHandle.handle(`` matched from its
    first inner capital and was reported as ``TicketHandle``.

    Args:
        body: Java source of the method body.
        handles_meta: Mapping whose keys are handle class names.

    Returns:
        The handle class name, or ``None`` when nothing matches.
    """
    for typed_pattern in (
        r"\b([A-Z]\w*Handle)\.handle\(",
        r"\b([A-Z]\w*Handle)\.(?:check|getContent)\(",
    ):
        typed = re.search(typed_pattern, body)
        if typed:
            return typed.group(1)

    # Bean-style field name -> handle class name (first declaration wins).
    by_field = {}
    for handle_name in handles_meta:
        if handle_name:
            by_field.setdefault(handle_name[0].lower() + handle_name[1:], handle_name)

    for field_pattern in (
        r"\b(\w+)\.handle\(",
        r"\b(\w+)\.(?:check|getContent)\(",
    ):
        for call in re.finditer(field_pattern, body):
            handle_name = by_field.get(call.group(1))
            if handle_name is not None:
                return handle_name
    return None
def generate_method(method, handle_name, field_name, handles_meta, class_name, biz_label, biz_key):
    """Render the Java source of one migrated controller method.

    Args:
        method: Parsed method dict with the keys ``mapping_raw``,
            ``method_name``, ``params``, ``body`` and optionally ``map_ann``.
        handle_name: Handle class name used to look up ``handles_meta``.
        field_name: Name of the injected handle field in the controller.
        handles_meta: Handle metadata keyed by handle class name.
        class_name: Unused; kept so existing callers keep working.
        biz_label: Label written to ``bizLabel``; with ``识别`` removed it is
            also written to ``targetLabel``.
        biz_key: Value written to both ``bizKey`` and ``sideKey``.

    Returns:
        The Java method text, or ``None`` when the handle is unknown or its
        response type is neither a map, a ``com.``-qualified name nor a
        capitalised class name.

    NOTE(review): annotations on the original method other than the mapping
    (e.g. ``@EbAuthentication``) are not carried over; the previous check
    looked for them inside the method body, where they never appear.
    NOTE(review): a handle whose ``uri`` is ``None`` is rendered as
    ``.uri("None")`` - confirm how such handles should be migrated.
    """
    handle = handles_meta.get(handle_name)
    if not handle:
        return None

    resp_type = handle["response"]
    use_map = resp_type in ("Map", "JSONObject", "java.util.Map")
    resp_type_java = "java.util.Map" if use_map else resp_type
    # [:1] keeps an empty type name from raising IndexError.
    if (
        not use_map
        and not resp_type_java.startswith("com.")
        and not resp_type_java[:1].isupper()
    ):
        return None

    params = method["params"]
    tokens = params.split() if params else []
    param_name = tokens[-1].replace(",", "").strip() if tokens else "request"
    if not param_name or param_name == "void":
        param_name = "request"

    uri = handle["uri"]
    pre_code = ""
    actual_param = param_name
    body = method["body"]

    if "new B" in body and "setImageUrl" in body:
        # IdCard style mapping - keep the original request-building statements
        # up to (and including a match on) the line that calls handle(.
        setup = []
        for raw_line in body.splitlines():
            line = raw_line.strip()
            if line.startswith("B") and "Request" in line and "= new" in line:
                setup.append(line)
            elif line.startswith("b") and ".set" in line:
                setup.append(line)
            if "handle(" in line:
                break
        if setup:
            pre_code = "\n ".join(setup)
            actual_param = _declared_variable(setup, actual_param)

    if "new B" in body and "setVerifyNum" in body:
        setup = []
        for raw_line in body.splitlines():
            line = raw_line.strip()
            if ("Request" in line and "= new" in line) or line.startswith("bRequest."):
                if "handle(" in line:
                    break
                setup.append(line)
        if setup:
            pre_code = "\n ".join(setup)
            actual_param = _declared_variable(setup, actual_param)

    post_mapping = method["mapping_raw"].strip()
    if not post_mapping.startswith("{"):
        # Strip outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        unquoted = post_mapping.strip('"')
        post_mapping = f'"{unquoted}"'

    pre_block = f"{pre_code}\n " if pre_code else ""
    map_ann = method.get("map_ann", "RequestMapping")
    target_label = biz_label.replace("识别", "")

    # @CacheResult is always emitted (the previous condition ended in
    # "or True", so it could never be false).
    java_lines = [
        f" @{map_ann}({post_mapping})",
        f"@CacheResult\n public R {method['method_name']}({params}) {{",
        f"{pre_block}return executeLegacyBaiduRecognize(LegacyBaiduRecognizeSpec.<{resp_type_java}>builder()",
        f'.bizLabel("{biz_label}")',
        f'.bizKey("{biz_key}")',
        f'.uri("{uri}")',
        f'.sideKey("{biz_key}")',
        f'.targetLabel("{target_label}")',
        f".inputLog(buildInputLog({actual_param}))",
        f".respClass({resp_type_java}.class)",
        f".paramCheck(() -> {field_name}.check({actual_param}))",
        f".requestContent(() -> {field_name}.getContent({actual_param}))",
        ".build());",
        "}",
    ]
    return "\n".join(java_lines)


def _declared_variable(setup_lines, fallback):
    """Return the variable assigned by the first ``x = new ...`` line.

    A Java declaration reads ``Type name = new Type();``, so the identifier
    directly before ``=`` is the variable; the first token is its type.
    ``fallback`` is returned when no line contains such an assignment.
    """
    for line in setup_lines:
        assigned = re.search(r"(\w+)\s*=\s*new\b", line)
        if assigned:
            return assigned.group(1)
    return fallback
def generate_controller(info, handles_meta):
    """Render the full Java source of a migrated controller.

    Args:
        info: Dict produced by ``parse_controller``.
        handles_meta: Handle metadata produced by ``load_handles``.

    Returns:
        The Java file content, or ``None`` when the controller cannot be
        generated: it injects no handle, its first handle has no metadata
        (this used to raise ``KeyError``), it has several handles but more
        methods than handles, or a method cannot be rendered.

    NOTE(review): with several injected handles, the n-th method is paired
    with the n-th handle field purely by declaration order - confirm this
    matches the controllers being migrated.
    """
    handle_names = [(handle_type, field_name) for handle_type, field_name in info["handles"]]
    if not handle_names:
        return None
    primary_meta = handles_meta.get(handle_names[0][0])
    if primary_meta is None:
        return None

    class_name = info["class_name"]
    source_text = info["text"]
    uri = primary_meta["uri"]
    biz_label = infer_biz_label(class_name, uri)
    biz_key = infer_biz_key(uri, class_name)

    # The first handle's response type decides the controller-wide default.
    resp_meta = primary_meta["response"]
    use_map = resp_meta in ("Map", "JSONObject", "java.util.Map")
    resp_type = "java.util.Map" if use_map else resp_meta

    has_not_intercept = "@NotIntercept" in source_text

    imports = {
        "com.heyu.api.controller.AbstractRecognizeController",
        "com.heyu.api.data.utils.R",
        "org.springframework.web.bind.annotation.RequestMapping",
        "org.springframework.web.bind.annotation.RestController",
        "org.springframework.beans.factory.annotation.Autowired",
        "com.heyu.api.data.annotation.CacheResult",
    }
    if info["has_slf4j"]:
        imports.add("lombok.extern.slf4j.Slf4j")
    for handle_type, _field in handle_names:
        imports.add(f"com.heyu.api.baidu.handle.{package_for_handle(handle_type)}")
        meta = handles_meta.get(handle_type, {})
        if meta.get("request"):
            imports.add(request_import(meta["request"]))
        if meta.get("response") and not use_map:
            imports.add(response_import(meta["response"]))
    # extra imports from original
    if "ApiIdentityCardRequest" in source_text:
        imports.add("com.heyu.api.request.certificate.ApiIdentityCardRequest")
    if "BusinesslicenseVerificationDetailedReq" in source_text:
        imports.add("com.heyu.api.request.certificate.BusinesslicenseVerificationDetailedReq")
    if use_map:
        imports.add("java.util.Map")
    if has_not_intercept:
        imports.add("com.heyu.api.data.annotation.NotIntercept")
    if "@EbAuthentication" in source_text:
        imports.add("com.heyu.api.data.annotation.EbAuthentication")
        imports.add("com.heyu.api.data.constants.ApiConstants")
    if any(m.get("map_ann") == "PostMapping" for m in info["methods"]):
        imports.add("org.springframework.web.bind.annotation.PostMapping")

    single_handle = len(handle_names) == 1
    methods_code = []
    for idx, method in enumerate(info["methods"]):
        if single_handle:
            handle, field_name = handle_names[0]
        elif idx < len(handle_names):
            handle, field_name = handle_names[idx]
        else:
            return None
        if not handle or not field_name:
            return None
        rendered = generate_method(
            method, handle, field_name, handles_meta, class_name, biz_label, biz_key
        )
        if not rendered:
            return None
        methods_code.append(rendered)

    lines = [f"package {info['pkg']};", ""]
    lines.extend(f"import {imp};" for imp in sorted(imports))
    lines.append("")
    # Optional annotations are appended only when present, so an absent one
    # no longer leaves a stray blank line in the class header.
    if info["has_slf4j"]:
        lines.append("@Slf4j")
    lines.append("@RestController")
    lines.append(f'@RequestMapping("{info["base_mapping"]}")')
    if has_not_intercept:
        lines.append("@NotIntercept")
    lines.extend(
        [
            f"public class {class_name} extends AbstractRecognizeController {{",
            "",
            f' private static final String SIDE_KEY = "{biz_key}";',
            "",
        ]
    )
    for handle_type, field_name in handle_names:
        lines.append(" @Autowired")
        lines.append(f" private {handle_type} {field_name};")
        lines.append("")
    lines.extend(methods_code)

    empty_resp = "new java.util.HashMap<>()" if use_map else "new " + resp_type + "()"
    lines.extend(
        [
            "",
            " @Override",
            " protected Object defaultEmptyResp(String side) {",
            f" return {empty_resp};",
            " }",
            "",
            " private String buildInputLog(Object request) {",
            " if (request == null) {",
            ' return "未收到请求体";',
            " }",
            ' return "已收到请求体,具体字段见业务入参";',
            " }",
            "}",
            "",
        ]
    )
    return "\n".join(lines)
def package_for_handle(handle_type: str, handle_root=None) -> str:
    """Return the dotted sub-package path of a handle class under the root.

    The first ``<handle_type>.java`` found below the root decides the result.
    A file in ``ocr/ticket/`` yields ``ocr.ticket.<handle_type>``. A file
    directly in the root, or a class that is not found at all, yields the
    bare ``handle_type``; previously the former produced a leading dot
    because its (empty) relative directory was still treated as present.

    Args:
        handle_type: Simple Java class name of the handle.
        handle_root: Directory to search. Defaults to ``HANDLE_ROOT``.
    """
    root = HANDLE_ROOT if handle_root is None else Path(handle_root)
    source = next(root.rglob(f"{handle_type}.java"), None)
    if source is None:
        return handle_type
    package_parts = source.relative_to(root).parent.parts
    return ".".join((*package_parts, handle_type))
def _qualified_java_name(simple_name, default_package, source_root=None):
    """Resolve a simple Java class name to a fully qualified one.

    The first ``<simple_name>.java`` below the Java source root supplies the
    package. The search is rooted at the source root itself, so a match
    elsewhere in the module can no longer make ``relative_to`` raise
    ``ValueError``. When nothing is found, ``default_package`` is used.
    """
    java_root = (
        ROOT / "api-third/src/main/java" if source_root is None else Path(source_root)
    )
    for path in java_root.rglob(f"{simple_name}.java"):
        return ".".join(path.relative_to(java_root).with_suffix("").parts)
    return f"{default_package}.{simple_name}"


def request_import(req: str, source_root=None) -> str:
    """Return the import path for a request class.

    Args:
        req: Simple Java class name.
        source_root: Java source root to search. Defaults to
            ``ROOT / "api-third/src/main/java"``.

    Returns:
        The qualified name derived from the file location, or
        ``com.heyu.api.baidu.request.<req>`` when no source file is found.
    """
    return _qualified_java_name(req, "com.heyu.api.baidu.request", source_root)


def response_import(resp: str, source_root=None) -> str:
    """Return the import path for a response class.

    Args:
        resp: Simple Java class name.
        source_root: Java source root to search. Defaults to
            ``ROOT / "api-third/src/main/java"``.

    Returns:
        The qualified name derived from the file location, or
        ``com.heyu.api.baidu.response.<resp>`` when no source file is found.
    """
    return _qualified_java_name(resp, "com.heyu.api.baidu.response", source_root)
def main():
    """Migrate every eligible controller in place and print a summary.

    Each ``*Controller.java`` under ``CTRL_ROOT`` (except names in ``SKIP``)
    is parsed. Files reported as ``"MANUAL"`` or ``"COMPLEX"`` are listed for
    manual work, files that parse to nothing are ignored, files whose code
    cannot be generated are listed as skipped, and all others are overwritten
    with the generated source.
    """
    handles_meta = load_handles()
    migrated = []
    skipped = []
    manual = []
    for path in sorted(CTRL_ROOT.rglob("*Controller.java")):
        if path.name in SKIP:
            continue
        info = parse_controller(path)
        if info in ("MANUAL", "COMPLEX"):
            manual.append(path)
            continue
        if not info:
            continue
        code = generate_controller(info, handles_meta)
        if not code:
            skipped.append(path)
            continue
        path.write_text(code, encoding="utf-8")
        migrated.append(path)

    for title, tag, paths in (
        ("migrated", " OK", migrated),
        ("skipped", " SKIP", skipped),
        ("manual", " MANUAL", manual),
    ):
        print(f"{title}: {len(paths)}")
        for p in paths:
            # relative_to + as_posix replaces the string replace on a
            # hard-coded "/" separator, so output is the same on any OS.
            print(tag, p.relative_to(ROOT).as_posix())


if __name__ == "__main__":
    main()