Kai Jia

PhD student

MIT CSAIL

Hello! I aim to create more reliable artificial intelligence systems, drawing on perspectives from programming languages and symbolic reasoning. My research interests span computer science, from computer systems (architecture, high-performance computing, and software engineering) to the computational modeling of complex phenomena in areas such as computer graphics and algorithmic game theory. I enjoy developing mathematical and computational tools for challenging problems, and I am especially passionate about designing and building large-scale, agile software systems for intricate tasks.

I currently work at Ecopia AI. I earned my PhD in computer science at MIT under the mentorship of Professor Martin Rinard, and I hold a B.E. in computer science from Tsinghua University. Before my PhD, I worked at a startup company, where I led the development of an in-house deep learning system that was later open-sourced as MegEngine.

@misc{jia2024limited,
  title={Limited-perception games},
  author={Kai Jia and Martin Rinard},
  year={2024},
  eprint={2405.16735},
  archivePrefix={arXiv},
  primaryClass={cs.GT},
  url={https://arxiv.org/abs/2405.16735}
}

@inproceedings{yang2022on,
  author="Yang, Yichen and Jia, Kai and Rinard, Martin",
  editor="Kanellopoulos, Panagiotis and Kyropoulou, Maria and Voudouris, Alexandros",
  title="On the Impact of Player Capability on Congestion Games",
  booktitle="Algorithmic Game Theory",
  year="2022",
  publisher="Springer International Publishing",
  address="Cham",
  pages="311--328",
  isbn="978-3-031-15714-1"
}

Motivated by the need to reliably characterize the robustness of deep neural networks, researchers have developed verification algorithms for them. Given a neural network, a verifier aims to determine whether certain properties are guaranteed to hold for all inputs in a given space. However, little attention has been paid to floating point numerical error in neural network verification.

We show that neglecting floating point error is easily exploitable in practice. For a pretrained neural network, we present a method that efficiently searches for inputs on which a complete verifier incorrectly claims the network is robust. We also present a method to construct neural network architectures and weights that induce wrong results from an incomplete verifier. Our results highlight that, to achieve practically reliable verification of neural networks, any verification system must accurately (or conservatively) model the effects of any floating point computations in the network inference or verification system.
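As a minimal illustration of the underlying hazard (a toy example, not the attack from the paper): floating-point addition is not associative, so two mathematically equivalent evaluation orders of the same sum can produce different results, while a verifier reasoning over exact real arithmetic would treat them as identical.

```python
# Toy demonstration (not the paper's method): floating-point addition is
# not associative, so two mathematically equivalent evaluation orders of
# the same sum disagree -- here the results differ by 1.0, enough to flip
# a classification near a decision boundary.
a, b, c = 1e16, -1e16, 1.0

left = (a + b) + c    # (a + b) is exactly 0.0, so the sum is 1.0
right = a + (b + c)   # (b + c) rounds back to -1e16, so the sum is 0.0

# A verifier reasoning over exact reals treats both orders as equal;
# an inference implementation necessarily picks one of them.
```

Near a decision boundary, a gap of this kind is exactly what lets an input be "verified robust" in exact arithmetic yet misclassified by the deployed float implementation.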

@inproceedings{jia2021exploiting,
  author="Jia, Kai and Rinard, Martin",
  editor="Dr{\u{a}}goi, Cezara and Mukherjee, Suvam and Namjoshi, Kedar",
  title="Exploiting Verified Neural Networks via Floating Point Numerical Error",
  booktitle="Static Analysis",
  year="2021",
  publisher="Springer International Publishing",
  address="Cham",
  pages="191--205",
  isbn="978-3-030-88806-0"
}

Deep neural networks are an attractive tool for compressing the control policy lookup tables in systems such as the Airborne Collision Avoidance System (ACAS). It is vital to ensure the safety of such neural controllers via verification techniques. The problem of analyzing ACAS Xu networks has motivated many successful neural network verifiers. These verifiers typically analyze the internal computation of neural networks to decide whether a property regarding the input/output holds. The intrinsic complexity of neural network computation renders such verifiers slow to run and vulnerable to floating-point error.

This paper revisits the original problem of verifying ACAS Xu networks. The networks take low-dimensional sensory inputs with training data extracted from a lookup table. We propose to prepend an input quantization layer to the network. Quantization allows efficient verification via input state enumeration, whose complexity is bounded by the size of the quantization space. Quantization is equivalent to nearest-neighbor interpolation at run time, which has been shown to provide acceptable accuracy for ACAS in simulation. Moreover, our technique can deliver exact verification results immune to floating-point error if we directly enumerate the network outputs on the target inference implementation or on an accurate simulation of the target implementation.
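The enumeration idea can be sketched on a made-up two-input network (the weights, region, grid step, and property below are illustrative assumptions, not the ACAS Xu networks): once a nearest-neighbor quantization layer is prepended, the set of reachable inputs is finite, so the property can be checked by running the actual float implementation on every grid point.

```python
# Sketch of verification by input enumeration (illustrative toy, not the
# ACAS Xu networks; weights, region, and property are made-up assumptions).
import itertools

def relu(v):
    return [max(0.0, x) for x in v]

def tiny_net(x):
    # hypothetical 2-input, 1-output network standing in for the controller
    h = relu([0.8 * x[0] - 0.5 * x[1] + 0.1,
              -0.3 * x[0] + 0.9 * x[1]])
    return 1.2 * h[0] - 0.7 * h[1] + 0.05

def quantize(x, step):
    # the prepended input quantization layer: nearest-neighbor rounding
    return [round(v / step) * step for v in x]

def verify_region(lo, hi, step, prop):
    # After quantization the reachable inputs form a finite grid, so we can
    # run the actual float implementation on every point; the result is
    # exact for this implementation, with complexity bounded by grid size.
    n = int(round((hi - lo) / step)) + 1
    grid = [lo + i * step for i in range(n)]
    return all(prop(tiny_net(quantize(x, step)))
               for x in itertools.product(grid, grid))

# check that the output stays above -1.0 on [0, 1]^2 with grid step 0.25
holds = verify_region(0.0, 1.0, 0.25, lambda y: y > -1.0)
```

Because the check executes the same code path as deployment, its verdict is immune to the floating-point soundness issues that affect analyses of the network's internal computation.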

@inproceedings{jia2021verifying,
  author="Jia, Kai and Rinard, Martin",
  editor="Dr{\u{a}}goi, Cezara and Mukherjee, Suvam and Namjoshi, Kedar",
  title="Verifying Low-Dimensional Input Neural Networks via Input Quantization",
  booktitle="Static Analysis",
  year="2021",
  publisher="Springer International Publishing",
  address="Cham",
  pages="206--214",
  isbn="978-3-030-88806-0"
}

@inproceedings{jia2020efficient,
  author={Jia, Kai and Rinard, Martin},
  booktitle={Advances in Neural Information Processing Systems},
  editor={H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
  pages={1782--1795},
  publisher={Curran Associates, Inc.},
  title={Efficient Exact Verification of Binarized Neural Networks},
  url={https://proceedings.neurips.cc/paper/2020/file/1385974ed5904a438616ff7bdb3f7439-Paper.pdf},
  volume={33},
  year={2020}
}

@misc{jia2024trafs,
  title={TRAFS: A Nonsmooth Convex Optimization Algorithm with $\mathcal{O}\left(\frac{1}{\epsilon}\right)$ Iteration Complexity},
  author={Kai Jia and Martin Rinard},
  year={2024},
  eprint={2311.06205},
  archivePrefix={arXiv},
  primaryClass={math.OC},
  url={https://arxiv.org/abs/2311.06205}
}

Solving nonlinear systems is an important problem. Numerical continuation methods efficiently solve certain nonlinear systems. The Asymptotic Numerical Method (ANM) is a powerful continuation method that usually converges faster than Newtonian methods. ANM explores the landscape of the function by following a parameterized solution curve approximated with a high-order power series. Although ANM has successfully solved a few graphics and engineering problems, prior to our work, applying ANM to new problems required significant effort because the standard ANM assumes quadratic functions, while manually deriving the power series expansion for nonquadratic systems is a tedious and challenging task.
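To make the power-series idea concrete, here is a hand-rolled scalar instance (an illustrative toy, not SANM): for the quadratic system x² + x − λ = 0 with x(0) = 0, matching coefficients of λᵏ yields a simple recurrence for the series of x(λ), which tracks the exact solution branch for small λ.

```python
# Toy illustration of the ANM idea (not SANM itself): follow the solution
# curve of the quadratic system x^2 + x - lam = 0, x(0) = 0, by computing
# the power series x(lam) = sum_k a_k * lam^k.

def anm_series(order):
    # Matching coefficients of lam^k in (sum a_i lam^i)^2 + sum a_i lam^i = lam
    # gives the recurrence a_k = [k == 1] - sum_{i=1}^{k-1} a_i * a_{k-i}.
    a = [0.0]
    for k in range(1, order + 1):
        a.append((1.0 if k == 1 else 0.0)
                 - sum(a[i] * a[k - i] for i in range(1, k)))
    return a

coeffs = anm_series(4)   # [0.0, 1.0, -1.0, 2.0, -5.0]

def x_series(lam):
    return sum(c * lam**k for k, c in enumerate(coeffs))

# The truncated series tracks the exact branch x = (-1 + sqrt(1 + 4*lam)) / 2.
lam = 0.01
approx = x_series(lam)
exact = (-1.0 + (1.0 + 4.0 * lam) ** 0.5) / 2.0
```

For this quadratic system the recurrence is mechanical; the difficulty SANM addresses is automating the analogous coefficient derivation for large symbolic systems with nonquadratic operators.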

This paper presents a novel solver, SANM, that applies ANM to solve symbolically represented nonlinear systems. SANM solves such systems in a fully automated manner. SANM also extends ANM to support many nonquadratic operators, including intricate ones such as singular value decomposition. Furthermore, SANM generalizes ANM to support the implicit homotopy form. Moreover, SANM achieves high computing performance via optimized system design and implementation.

We deploy SANM to solve forward and inverse elastic force equilibrium problems and controlled mesh deformation problems with several constitutive models. Our results show that SANM converges faster than Newtonian solvers, requires little programming effort for new problems, and delivers comparable or better performance than a hand-coded, specialized ANM solver. While we demonstrate SANM on mesh deformation problems, it is generic and potentially applicable to many other tasks.

@article{jia2021sanm,
  title={{SANM}: A Symbolic Asymptotic Numerical Solver with Applications in Mesh Deformation},
  author={Jia, Kai},
  journal={{ACM} Transactions on Graphics (Proc. {SIGGRAPH})},
  publisher={ACM},
  year={2021},
  volume={40},
  number={4}
}

Background

Pancreatic Ductal Adenocarcinoma (PDAC) screening can enable early-stage disease detection and long-term survival. Current guidelines base screening eligibility on inherited predisposition, making about 10% of PDAC cases eligible. Using Electronic Health Record (EHR) data from a multi-institutional federated network, we developed and validated a PDAC RISk Model (Prism) for the general US population to extend early PDAC detection.

Methods

Neural Network (PrismNN) and Logistic Regression (PrismLR) were developed using EHR data from 55 US Health Care Organisations (HCOs) to predict PDAC risk 6–18 months before diagnosis for patients 40 years or older. Model performance was assessed using Area Under the Curve (AUC) and calibration plots. Models were internal-externally validated by geographic location, race, and time. Simulated model deployment evaluated Standardised Incidence Ratio (SIR) and other metrics.
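For readers unfamiliar with the metric, AUC can be computed directly from its probabilistic definition. The sketch below uses made-up toy data; the actual study evaluated on held-out EHR test sets.

```python
# Illustrative computation of AUC from its probabilistic definition
# (toy data; the actual study evaluated on held-out EHR test sets).

def auc(labels, scores):
    # AUC = probability that a uniformly random positive case receives a
    # higher score than a uniformly random negative case (ties count 1/2),
    # i.e. the normalized Mann-Whitney U statistic.
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# hypothetical risk scores: 1 = PDAC case, 0 = control
labels = [1, 1, 1, 0, 0, 0, 0]
scores = [0.9, 0.8, 0.3, 0.7, 0.2, 0.1, 0.1]
value = auc(labels, scores)   # 11 of 12 case-control pairs ordered correctly
```

An AUC of 0.826, as reported for PrismNN, means a randomly chosen future PDAC case receives a higher risk score than a randomly chosen control about 83% of the time.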

Findings

With 35,387 PDAC cases, 1,500,081 controls, and 87 features per patient, PrismNN obtained a test AUC of 0.826 (95% CI: 0.824–0.828) (PrismLR: 0.800 (95% CI: 0.798–0.802)). PrismNN's average internal-external validation AUCs were 0.740 for locations, 0.828 for races, and 0.789 (95% CI: 0.762–0.816) for time. At SIR = 5.10 (exceeding the current screening inclusion threshold) in simulated model deployment, PrismNN sensitivity was 35.9% (specificity 95.3%).

Interpretation

Prism models demonstrated good accuracy and generalizability across diverse populations. At comparable risk levels, PrismNN could identify 3.5 times more cases than current screening guidelines. The small number of features provided a basis for model interpretation. Integration with the federated network provided data from a large, heterogeneous patient population and a pathway to future clinical deployment.

Funding

Prevent Cancer Foundation, TriNetX, Boeing, DARPA, NSF, and Aarno Labs.

@article{jia2023pancreatic,
  title={A pancreatic cancer risk prediction model (Prism) developed and validated on large-scale US clinical data},
  author={Jia, Kai and Kundrot, Steven and Palchuk, Matvey B and Warnick, Jeff and Haapala, Kathryn and Kaplan, Irving D and Rinard, Martin and Appelbaum, Limor},
  journal={EBioMedicine},
  volume={98},
  year={2023},
  publisher={Elsevier}
}

@misc{jia2023effective,
  title={Effective Neural Network $L_0$ Regularization With BinMask},
  author={Kai Jia and Martin Rinard},
  year={2023},
  eprint={2304.11237},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2304.11237}
}

@misc{jia2023sound,
  title={Sound Explanation for Trustworthy Machine Learning},
  author={Kai Jia and Pasapol Saowakon and Limor Appelbaum and Martin Rinard},
  year={2023},
  eprint={2306.06134},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2306.06134}
}