509 lines
17 KiB
Python
509 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""Migrate Baidu-only controllers to AbstractRecognizeController pattern."""
|
|
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
|
|
# Repository root: this script is expected to live one directory below it.
ROOT = Path(__file__).resolve().parents[1]

# Spring controllers that are scanned as migration candidates.
CTRL_ROOT = ROOT.joinpath(
    "api-web", "api-interface", "src", "main", "java",
    "com", "heyu", "api", "controller",
)

# Baidu handle classes whose generic signatures/URIs are harvested by load_handles().
HANDLE_ROOT = ROOT.joinpath(
    "api-third", "src", "main", "java",
    "com", "heyu", "api", "baidu", "handle",
)

# Controller file names that are never rewritten.
SKIP = {
    "AbstractRecognizeController.java",
    "BaseController.java",
    "TestController.java",
}

# Business labels keyed by snake_case biz key.
# NOTE(review): not referenced by the functions below — confirm before relying on it.
BIZ_NAMES = dict(
    bus_ticket="汽车票识别",
    invoice="通用机打发票识别",
    seal="印章识别",
    idcard="身份证识别",
    medical_outpatient="门诊病历识别",
)
|
|
|
|
# Typed handles: `extends BaiduBaseHandle<Request, Response>`.
_BASE_HANDLE_RE = re.compile(
    r"class\s+(\w+)\s+extends\s+BaiduBaseHandle<([^,>]+),\s*([^>]+)>"
)
# Raw handles that only declare a request type; their response is treated as a Map.
_RAW_HANDLE_RE = re.compile(
    r"class\s+(\w+)\s+extends\s+(?:BEducationJsonHandle|BaiduFormRawHandle|BaiduMultipartRawHandle)<([^>]+)>"
)
# String literal returned by getUri(); `[^}]*` keeps the match inside that method body.
_GET_URI_RE = re.compile(r'getUri\(\)\s*\{[^}]*return\s+"([^"]+)"', re.S)


def load_handles(handle_root=None):
    """Scan Baidu handle sources and collect per-handle metadata.

    Args:
        handle_root: directory to scan recursively for ``*Handle.java`` files.
            Defaults to ``HANDLE_ROOT``.

    Returns:
        dict mapping handle class name -> {"request", "response", "uri", "json_body"}.

        * ``BaiduBaseHandle<Req, Resp>`` subclasses are included only when a
          literal ``getUri()`` return value is found; ``json_body`` is False.
        * ``BEducationJsonHandle`` / ``BaiduFormRawHandle`` /
          ``BaiduMultipartRawHandle`` subclasses are always included with
          response ``"java.util.Map"``, ``json_body`` True, and ``uri`` set to
          the ``getUri()`` literal or None when none is found.
    """
    root = HANDLE_ROOT if handle_root is None else Path(handle_root)
    handles = {}
    for path in root.rglob("*Handle.java"):
        text = path.read_text(encoding="utf-8", errors="ignore")
        uri_m = _GET_URI_RE.search(text)
        uri = uri_m.group(1) if uri_m else None

        typed = _BASE_HANDLE_RE.search(text)
        if typed:
            # A typed handle is only usable for migration when its endpoint is known.
            if uri is not None:
                handles[typed.group(1)] = {
                    "request": typed.group(2).strip(),
                    "response": typed.group(3).strip(),
                    "uri": uri,
                    "json_body": False,
                }
            continue

        raw = _RAW_HANDLE_RE.search(text)
        if raw:
            handles[raw.group(1)] = {
                "request": raw.group(2).strip(),
                "response": "java.util.Map",
                "uri": uri,
                "json_body": True,
            }
    return handles
|
|
|
|
|
|
def is_simple_method(body: str) -> bool:
    """Heuristically decide whether a controller method body is a plain delegate.

    A body containing a ``for`` loop is never simple. A body that returns
    ``BaiduOcrResult.raw`` is always simple. Otherwise the body must call
    ``.handle(`` and either pass errors straight through (``return R.error``
    without building an ``ArrayList``) or branch once on
    ``if (bR.isSuccess())`` with at most three ``return`` keywords.
    """
    # Loops imply the method post-processes results; leave those for manual work.
    if any(token in body for token in ("for (", "for(")):
        return False
    if "BaiduOcrResult.raw" in body:
        return True
    if ".handle(" not in body:
        return False

    passes_errors_through = "return R.error" in body and "new ArrayList" not in body
    single_success_branch = (
        "if (bR.isSuccess())" in body and body.count("return") <= 3
    )
    return passes_errors_through or single_success_branch
|
|
|
|
|
|
# Chinese business labels keyed by the controller name with "Controller" removed
# and a leading "Recognize" stripped (see infer_biz_label).
_BIZ_LABEL_BY_BASE = {
    "BusTicket": "汽车票识别",
    "Invoice": "通用机打发票识别",
    "Seal": "印章识别",
    "TaxiReceipt": "出租车票识别",
    "AirTicket": "飞机行程单识别",
    "BankReceiptNew": "银行回单识别",
    "BankReceiptNewPro": "银行回单识别(专业版)",
    "FerryTicket": "船票识别",
    "ShoppingReceipt": "购物小票识别",
    "TollInvoice": "过路过桥费发票识别",
    "QuotaInvoice": "定额发票识别",
    "MultipleInvoice": "混贴发票识别",
    "OnlineTaxiItinerary": "网约车行程单识别",
    "VatInvoiceVerification": "增值税发票核验",
    # These were previously keyed with a "Recognize" prefix, which is stripped
    # before lookup, so they could never match.
    "VATInvoice": "增值税发票识别",
    "VINCode": "VIN码识别",
    "QrCode": "二维码识别",
    "LicensePlate": "车牌识别",
    "IdCard": "身份证识别",
    "IdCardMulti": "身份证混贴识别",
    "ForgeryDetection": "图片篡改检测",
}


def infer_biz_label(class_name, uri):
    """Return the human-readable business label for a controller class.

    Args:
        class_name: Java controller class name, e.g. ``"RecognizeSealController"``.
        uri: handle URI; currently unused, kept for interface compatibility.

    Returns:
        The mapped label from ``_BIZ_LABEL_BY_BASE``, or ``<base>识别`` when the
        stripped class name is not mapped.
    """
    base = class_name.replace("Controller", "").removeprefix("Recognize")
    label = _BIZ_LABEL_BY_BASE.get(base)
    if label is not None:
        return label
    # Unmapped names fall back to the raw CamelCase base plus the suffix.
    return base + "识别"
|
|
|
|
|
|
def infer_biz_key(uri, class_name):
    """Return the business/side key for a controller.

    Uses the last non-empty path segment of *uri* (trailing slashes ignored);
    when *uri* is empty/None or has no usable segment, falls back to the
    lower-cased class name with "Controller" removed.
    """
    last_segment = uri.rstrip("/").rpartition("/")[2] if uri else ""
    return last_segment or class_name.replace("Controller", "").lower()
|
|
|
|
|
|
# A mapped controller method: optional leading annotations, @RequestMapping or
# @PostMapping with a string or {...} argument, optional trailing annotations,
# then a `public R` / `public R<...>` signature up to the opening brace.
_METHOD_RE = re.compile(
    r"(@\w+(?:\([^)]*\))?\s+)*@(?P<mapAnn>RequestMapping|PostMapping)\((?P<mapping_raw>\{[^}]+\}|\"[^\"]+\")\)\s+"
    r"((?:@\w+(?:\([^)]*\))?\s+)*)"
    r"public\s+R(?:<[^>]+>)?\s+(?P<method_name>\w+)\((?P<params>[^)]*)\)\s*\{",
    re.S,
)


def _class_base_mapping(header: str) -> str:
    """Return the class-level @RequestMapping path found in *header*, or "".

    Only the ``@RequestMapping("...")`` and ``@RequestMapping({...})`` forms are
    recognised; for the array form the first ``"/..."`` entry is used.
    """
    mapping_m = re.search(r"@RequestMapping\((\{[^}]+\}|\"[^\"]+\")", header)
    if not mapping_m:
        return ""
    base_mapping = mapping_m.group(1).strip().strip('"')
    if base_mapping.startswith("{"):
        inner = re.search(r'"(/[^"]+)"', base_mapping)
        base_mapping = inner.group(1) if inner else base_mapping
    return base_mapping


def _method_body(text: str, start: int) -> str:
    """Return the text from *start* up to the brace closing the one opened just before it.

    Braces are counted naively: braces inside string literals or comments are
    not skipped.
    """
    depth = 1
    i = start
    while i < len(text) and depth:
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
        i += 1
    return text[start : i - 1]


def parse_controller(path: Path):
    """Parse a Baidu-only controller source file into a migration description.

    Returns:
        * ``None`` if the file already extends AbstractRecognizeController,
          does not reference ``com.heyu.api.baidu``, or also references
          ``com.heyu.api.alibaba``.
        * ``"MANUAL"`` for files listed for manual migration or whose
          ``public [final] class`` declaration cannot be found.
        * ``"COMPLEX"`` when any mapped method body fails is_simple_method().
        * Otherwise a dict with keys path, pkg, class_name, base_mapping,
          annotations, has_slf4j, methods, handles and text.
    """
    text = path.read_text(encoding="utf-8")
    if "extends AbstractRecognizeController" in text:
        return None
    if "com.heyu.api.baidu" not in text:
        return None
    if "com.heyu.api.alibaba" in text:
        return None

    # skip manual complex for now
    if path.name in ("MedicalOutpatientController.java",):
        return "MANUAL"

    class_m = re.search(r"public\s+(?:final\s+)?class\s+(\w+)", text)
    if not class_m:
        # Unrecognised declaration: report it for manual work instead of crashing.
        return "MANUAL"
    class_name = class_m.group(1)

    pkg_m = re.search(r"package\s+([\w.]+);", text)
    pkg = pkg_m.group(1) if pkg_m else ""

    # Class-level annotations precede the class declaration; searching only that
    # header avoids mistaking a method-level @RequestMapping for the base path.
    base_mapping = _class_base_mapping(text[: class_m.start()])

    annotations = ["@NotIntercept"] if "@NotIntercept" in text else []
    has_slf4j = "@Slf4j" in text

    methods = []
    for m in _METHOD_RE.finditer(text):
        body = _method_body(text, m.end())
        if not is_simple_method(body):
            return "COMPLEX"
        methods.append(
            {
                "mapping_raw": m.group("mapping_raw"),
                "map_ann": m.group("mapAnn"),
                "method_name": m.group("method_name"),
                "params": m.group("params").strip(),
                "body": body,
            }
        )

    # (handle type, field name) pairs for every `private XxxHandle field;`.
    handle_injects = re.findall(r"private\s+(\w+Handle)\s+(\w+);", text)

    return {
        "path": path,
        "pkg": pkg,
        "class_name": class_name,
        "base_mapping": base_mapping,
        "annotations": annotations,
        "has_slf4j": has_slf4j,
        "methods": methods,
        "handles": handle_injects,
        "text": text,
    }
|
|
|
|
|
|
def extract_handle_from_body(body, handles_meta):
    """Guess which handle class a controller method body delegates to.

    Resolution order:
      1. a capitalised ``XxxHandle.handle(`` reference,
      2. a capitalised ``XxxHandle.check(`` / ``.getContent(`` reference,
      3. the receiver of ``field.handle(`` matched against the lowerCamel bean
         name of each key in *handles_meta*.

    The leading ``\\b`` ensures a lowerCamel field such as ``bIdcardHandle`` is
    not misread as a type reference to ``IdcardHandle``.

    Returns:
        The handle class name, or None when nothing matches.
    """
    for pattern in (
        r"\b([A-Z]\w*Handle)\.handle\(",
        r"\b([A-Z]\w*Handle)\.(?:check|getContent)\(",
    ):
        hit = re.search(pattern, body)
        if hit:
            return hit.group(1)

    receiver = re.search(r"(\w+)\.handle\(", body)
    if receiver is None:
        return None
    field = receiver.group(1)
    for handle_name in handles_meta:
        if handle_name[0].lower() + handle_name[1:] == field:
            return handle_name
    return None
|
|
|
|
|
|
def _setup_before_handle(body, keep):
    """Return stripped lines of *body* accepted by *keep*, stopping at the first ``handle(`` call."""
    setup = []
    for raw_line in body.splitlines():
        line = raw_line.strip()
        if "handle(" in line:
            break
        if keep(line):
            setup.append(line)
    return setup


def _declared_variable(statements):
    """Return the variable name declared by the first ``X name = new ...`` statement, or None."""
    for stmt in statements:
        lhs, sep, _ = stmt.partition("= new")
        if sep:
            tokens = lhs.split()
            if tokens:
                return tokens[-1]
    return None


def generate_method(method, handle_name, field_name, handles_meta, class_name, biz_label, biz_key):
    """Render one migrated controller method as Java source.

    Args:
        method: dict with ``mapping_raw``, ``method_name``, ``params``, ``body``
            and optionally ``map_ann`` (defaults to ``"RequestMapping"``).
        handle_name: handle class name used to look up *handles_meta*.
        field_name: name of the injected handle field in the controller.
        handles_meta: handle metadata (``response`` and ``uri`` are read).
        class_name: controller class name; unused, kept for interface compatibility.
        biz_label: business label; ``识别`` is removed to form the target label.
        biz_key: key used for both ``bizKey`` and ``sideKey``.

    Returns:
        The Java method text, or None when the method cannot be migrated safely:
        unknown handle, handle without a URI, a response type that is neither a
        Map alias, a ``com.`` name nor a capitalised class name, or a request
        conversion whose declaration cannot be located.

    NOTE(review): the method dict carries no annotations other than the mapping,
    so annotations such as @EbAuthentication on the original method are not
    reproduced; ``@CacheResult`` is always emitted.
    """
    handle = handles_meta.get(handle_name)
    if not handle:
        return None
    uri = handle["uri"]
    if uri is None:
        # Emitting `.uri("None")` would compile but call the wrong endpoint.
        return None

    resp_type = handle["response"]
    use_map = resp_type in ("Map", "JSONObject", "java.util.Map")
    resp_type_java = "java.util.Map" if use_map else resp_type
    if (
        not use_map
        and not resp_type_java.startswith("com.")
        and not resp_type_java[:1].isupper()
    ):
        return None

    params = method["params"]
    param_tokens = params.split()
    param_name = param_tokens[-1].replace(",", "").strip() if param_tokens else "request"
    if not param_name or param_name == "void":
        param_name = "request"

    body = method["body"]
    setup_lines = []
    if "new B" in body and "setImageUrl" in body:
        # IdCard style: the controller maps its own request onto a Baidu request.
        found = _setup_before_handle(
            body,
            lambda s: (s.startswith("B") and "Request" in s and "= new" in s)
            or (s.startswith("b") and ".set" in s),
        )
        if found:
            setup_lines = found
    if "new B" in body and "setVerifyNum" in body:
        found = _setup_before_handle(
            body,
            lambda s: ("Request" in s and "= new" in s) or s.startswith("bRequest."),
        )
        if found:
            setup_lines = found

    actual_param = param_name
    if setup_lines:
        declared = _declared_variable(setup_lines)
        if declared is None:
            # Setter calls without a visible declaration would not compile.
            return None
        actual_param = declared

    post_mapping = method["mapping_raw"].strip()
    if not post_mapping.startswith("{"):
        post_mapping = '"' + post_mapping.strip('"') + '"'

    map_ann = method.get("map_ann", "RequestMapping")
    target_label = biz_label.replace("识别", "")

    java = [
        f"    @{map_ann}({post_mapping})",
        "    @CacheResult",
        f"    public R {method['method_name']}({params}) {{",
        *(f"        {stmt}" for stmt in setup_lines),
        f"        return executeLegacyBaiduRecognize(LegacyBaiduRecognizeSpec.<{resp_type_java}>builder()",
        f'                .bizLabel("{biz_label}")',
        f'                .bizKey("{biz_key}")',
        f'                .uri("{uri}")',
        f'                .sideKey("{biz_key}")',
        f'                .targetLabel("{target_label}")',
        f"                .inputLog(buildInputLog({actual_param}))",
        f"                .respClass({resp_type_java}.class)",
        f"                .paramCheck(() -> {field_name}.check({actual_param}))",
        f"                .requestContent(() -> {field_name}.getContent({actual_param}))",
        "                .build());",
        "    }",
    ]
    return "\n".join(java)
|
|
|
|
|
|
def _handle_for_method(method, handle_names, index):
    """Pick the (handle_type, field_name) pair a controller method delegates to.

    With a single injected handle that handle is used. Otherwise the field whose
    ``<field>.handle(`` call appears in the method body wins; failing that, the
    method is paired positionally with the *index*-th injected handle. Returns
    None when no pairing is possible.
    """
    if len(handle_names) == 1:
        return handle_names[0]
    body = method.get("body", "")
    for handle_type, field_name in handle_names:
        if re.search(rf"\b{re.escape(field_name)}\.handle\(", body):
            return handle_type, field_name
    if index < len(handle_names):
        return handle_names[index]
    return None


def generate_controller(info, handles_meta):
    """Render the migrated Java source for one parsed controller.

    Args:
        info: dict produced by parse_controller().
        handles_meta: handle metadata keyed by handle class name (see load_handles()).

    Returns:
        The complete Java source, or None when the controller cannot be migrated
        automatically: no injected handles, the first handle has no metadata, a
        method cannot be paired with a handle, or generate_method() rejects one.
    """
    class_name = info["class_name"]
    text = info["text"]
    handle_names = list(info["handles"])
    if not handle_names:
        return None

    primary_meta = handles_meta.get(handle_names[0][0])
    if primary_meta is None:
        # Without metadata for the first handle there is no response type to use.
        return None

    uri = primary_meta["uri"]
    biz_label = infer_biz_label(class_name, uri)
    biz_key = infer_biz_key(uri, class_name)

    resp_meta = primary_meta["response"]
    use_map = resp_meta in ("Map", "JSONObject", "java.util.Map")
    resp_type = "java.util.Map" if use_map else resp_meta

    imports = {
        "com.heyu.api.controller.AbstractRecognizeController",
        "com.heyu.api.data.utils.R",
        "org.springframework.web.bind.annotation.RequestMapping",
        "org.springframework.web.bind.annotation.RestController",
        "org.springframework.beans.factory.annotation.Autowired",
        "com.heyu.api.data.annotation.CacheResult",
    }
    if info["has_slf4j"]:
        imports.add("lombok.extern.slf4j.Slf4j")
    for handle_type, _field in handle_names:
        imports.add(f"com.heyu.api.baidu.handle.{package_for_handle(handle_type)}")
        meta = handles_meta.get(handle_type, {})
        if meta.get("request"):
            imports.add(request_import(meta["request"]))
        if meta.get("response") and not use_map:
            imports.add(response_import(meta["response"]))

    # extra imports carried over from the original source
    if "ApiIdentityCardRequest" in text:
        imports.add("com.heyu.api.request.certificate.ApiIdentityCardRequest")
    if "BusinesslicenseVerificationDetailedReq" in text:
        imports.add("com.heyu.api.request.certificate.BusinesslicenseVerificationDetailedReq")
    if use_map:
        imports.add("java.util.Map")
    if "@NotIntercept" in text:
        imports.add("com.heyu.api.data.annotation.NotIntercept")
    if "@EbAuthentication" in text:
        imports.add("com.heyu.api.data.annotation.EbAuthentication")
        imports.add("com.heyu.api.data.constants.ApiConstants")
    if any(m.get("map_ann") == "PostMapping" for m in info["methods"]):
        imports.add("org.springframework.web.bind.annotation.PostMapping")

    methods_code = []
    for method in info["methods"]:
        picked = _handle_for_method(method, handle_names, len(methods_code))
        if picked is None:
            return None
        handle, field_name = picked
        if not handle or not field_name:
            return None
        mc = generate_method(method, handle, field_name, handles_meta, class_name, biz_label, biz_key)
        if not mc:
            return None
        methods_code.append(mc)

    # None entries are dropped when joining, so absent annotations leave no blank line.
    lines = [f"package {info['pkg']};", ""]
    lines += [f"import {imp};" for imp in sorted(imports)]
    lines += [
        "",
        "@Slf4j" if info["has_slf4j"] else None,
        "@RestController",
        f'@RequestMapping("{info["base_mapping"]}")',
        "@NotIntercept" if "@NotIntercept" in text else None,
        f"public class {class_name} extends AbstractRecognizeController {{",
        "",
        f'    private static final String SIDE_KEY = "{biz_key}";',
        "",
    ]
    for handle_type, field_name in handle_names:
        lines += ["    @Autowired", f"    private {handle_type} {field_name};", ""]
    if methods_code:
        lines.append("\n\n".join(methods_code))

    empty_resp = "new java.util.HashMap<>()" if use_map else f"new {resp_type}()"
    lines += [
        "",
        "    @Override",
        "    protected Object defaultEmptyResp(String side) {",
        f"        return {empty_resp};",
        "    }",
        "",
        "    private String buildInputLog(Object request) {",
        "        if (request == null) {",
        '            return "未收到请求体";',
        "        }",
        '        return "已收到请求体,具体字段见业务入参";',
        "    }",
        "}",
        "",
    ]
    return "\n".join(line for line in lines if line is not None)
|
|
|
|
|
|
def package_for_handle(handle_type: str, handle_root=None) -> str:
    """Return the import path of *handle_type* relative to ``com.heyu.api.baidu.handle``.

    Args:
        handle_type: handle class name (file ``<handle_type>.java``).
        handle_root: directory to search; defaults to ``HANDLE_ROOT``.

    Returns:
        ``"sub.pkg.HandleType"`` for a file found in a sub-package, plain
        ``"HandleType"`` when it sits directly in the root or is not found.
    """
    root = HANDLE_ROOT if handle_root is None else Path(handle_root)
    for path in root.rglob(f"{handle_type}.java"):
        # A file directly under root has no parent parts; joining them alone
        # previously produced a leading "." and hence a double-dot import.
        subpackages = path.relative_to(root).parent.parts
        return ".".join((*subpackages, handle_type))
    return handle_type
|
|
|
|
|
|
def request_import(req: str, src_root=None) -> str:
    """Return the fully-qualified Java name to import for request type *req*.

    Args:
        req: simple or fully-qualified request class name.
        src_root: Java source root to search; defaults to
            ``ROOT / "api-third/src/main/java"``.

    Returns:
        *req* unchanged if it already contains a dot; otherwise the qualified
        name of the first ``<req>.java`` under *src_root*, falling back to
        ``com.heyu.api.baidu.request.<req>``.
    """
    if "." in req:
        return req
    root = ROOT / "api-third/src/main/java" if src_root is None else Path(src_root)
    # Search the source root itself: matches elsewhere in api-third (tests, build
    # output) previously made relative_to() raise ValueError.
    for path in root.rglob(f"{req}.java"):
        return ".".join(path.relative_to(root).with_suffix("").parts)
    return f"com.heyu.api.baidu.request.{req}"
|
|
|
|
|
|
def response_import(resp: str, src_root=None) -> str:
    """Return the fully-qualified Java name to import for response type *resp*.

    Args:
        resp: simple or fully-qualified response class name.
        src_root: Java source root to search; defaults to
            ``ROOT / "api-third/src/main/java"``.

    Returns:
        *resp* unchanged if it already contains a dot; otherwise the qualified
        name of the first ``<resp>.java`` under *src_root*, falling back to
        ``com.heyu.api.baidu.response.<resp>``.
    """
    if "." in resp:
        return resp
    root = ROOT / "api-third/src/main/java" if src_root is None else Path(src_root)
    # Search the source root itself: matches elsewhere in api-third (tests, build
    # output) previously made relative_to() raise ValueError.
    for path in root.rglob(f"{resp}.java"):
        return ".".join(path.relative_to(root).with_suffix("").parts)
    return f"com.heyu.api.baidu.response.{resp}"
|
|
|
|
|
|
def main():
    """Migrate every eligible controller under CTRL_ROOT in place and print a report.

    Controllers parse_controller() flags as "MANUAL"/"COMPLEX" are listed as
    manual; those generate_controller() cannot render are listed as skipped;
    the rest are overwritten with the generated source.
    """
    handles_meta = load_handles()
    migrated, skipped, manual = [], [], []
    for path in sorted(CTRL_ROOT.rglob("*Controller.java")):
        if path.name in SKIP:
            continue
        info = parse_controller(path)
        if info in ("MANUAL", "COMPLEX"):
            manual.append(path)
            continue
        if not info:
            continue
        code = generate_controller(info, handles_meta)
        if not code:
            skipped.append(path)
            continue
        path.write_text(code, encoding="utf-8")
        migrated.append(path)

    # relative_to() is separator-agnostic, unlike replacing a hard-coded "/".
    for title, tag, paths in (
        ("migrated", "OK", migrated),
        ("skipped", "SKIP", skipped),
        ("manual", "MANUAL", manual),
    ):
        print(f"{title}: {len(paths)}")
        for p in paths:
            print(f"  {tag}", p.relative_to(ROOT).as_posix())


if __name__ == "__main__":
    main()
|