Commit b08d9ba

Update design for PR xai-org#315: added jaxlib specific version
1 parent 4f45f05 commit b08d9ba

3 files changed: +90 -3 lines changed

.exp/design-workflow-1-grok-1-inference-and-sampling.md

Lines changed: 3 additions & 2 deletions
@@ -62,7 +62,8 @@ sequenceDiagram
 Checkpoint->>MR: Sharded params (TrainingState)
 IR->>IR: Load tokenizer, compile pjit funcs (sample_step, prefill_memory, new_memory) with shardings
 IR->>IR: Precompile with dummy prompts for pad_sizes
-RunPy->>IR: gen = run() // generator setup with initial memory, settings, etc.
+RunPy->>IR: gen = run()
+Note over RunPy,IR: generator setup with initial memory, settings, etc.
 ```

 ## Inference and Sampling Sequence
@@ -127,6 +128,6 @@ sequenceDiagram
 - **Error/Edge Cases**: Assumes sufficient memory/GPUs; handles long contexts by left-truncation/padding. No built-in EOS handling (relies on max_len or app logic). Quantized weights require custom unpickling.
 - **Performance Notes**: MoE router/experts use JAX vmap/shard_map (serial per-token, inefficient for prod). Focus on correctness/single-host validation.
 - **Extensibility**: Modular Haiku design allows custom configs/modules. Generator interface suits serving multiple prompts concurrently.
-- **Dependencies & Setup**: `requirements.txt` (jax[cuda12_pip], haiku, etc.). Download ckpt via torrent/HF, place in checkpoints/.
+- **Dependencies & Setup**: `requirements.txt` which includes pinned versions: jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html (for CUDA12 support), jaxlib==0.4.25 (added in PR #315 to prevent version mismatch errors with jax during startup, especially on Windows), dm_haiku==0.0.12, numpy==1.26.4, sentencepiece==0.2.0. Download checkpoint via torrent or HuggingFace, place in checkpoints/.

 This document captures the high-level design, derived from code analysis.

.exp/project-overview.md

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ C4Context

 - **Supporting Files**:
 - `tokenizer.model`: Binary SentencePiece tokenizer (131k vocab).
-- `requirements.txt`: Dependencies (JAX CUDA, Haiku, NumPy, SentencePiece).
+- `requirements.txt`: Dependencies including jax[cuda12-pip]==0.4.25 with CUDA12 support (via JAX releases URL), jaxlib==0.4.25 (pinned via PR #315 to avoid incompatibility errors with jax), dm_haiku==0.0.12, numpy==1.26.4, sentencepiece==0.2.0.
 - `checkpoints/`: Directory for weights (must download `ckpt-0/`).
 - `pyproject.toml`: Linting config (Ruff).

pr-analysis-315.md

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# PR #315: Workflow Design Impact Analysis

## Affected Workflows
- **Grok-1 Inference and Sampling** (Workflow 1): Impacted because it relies on JAX for model initialization, distributed sharding via meshes and pjit, and generation. Pinning jaxlib in requirements.txt keeps the JAX and jaxlib versions compatible, preventing import and startup errors in the files the workflow definition references (run.py, runners.py, and model.py).
- **Model Loading and Initialization** (Workflow 2): Impacted because the workflow uses JAX to create meshes, transform Haiku functions with pjit, and shard parameters during checkpoint restore in runners.py and checkpoint.py. The version pin avoids a RuntimeError during these JAX-dependent operations.
- **Model Forward Pass and Logits Computation** (Workflow 3): Impacted because computing logits involves JAX NumPy operations, experimental maps, and sharded forward passes in model.py and runners.py. The PR resolves dependency issues that would otherwise halt execution at the import stage.

## Grok-1 Inference and Sampling Analysis
### Summary of design changes
The PR adds `jaxlib==0.4.25` to requirements.txt to match jax 0.4.25, fixing the incompatibility errors reported in the PR description (e.g., on Windows). This affects only the setup phase: the workflow becomes more robust because the JAX libraries load correctly before any code runs. Benefits: reliable installation across operating systems, with no manual version adjustments. Implications: no runtime logic changes, but inference is easier to get running.
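
The pin can be checked mechanically. The following is a minimal sketch, assuming the requirement lines listed in the updated design notes; the `REQUIREMENTS` string (including its line order) and the helper names `parse_pins` and `same_pin` are illustrative and not part of the repository:

```python
import re

# Illustrative input: the pins named in the design notes. The line order is assumed.
REQUIREMENTS = """\
dm_haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jaxlib==0.4.25
numpy==1.26.4
sentencepiece==0.2.0
"""

# Package name, optional [extras], then an exact "==" version.
# Anything after the version (such as "-f <url>") is ignored.
PIN = re.compile(r"^\s*([A-Za-z0-9_.-]+)(?:\[[^\]]*\])?\s*==\s*([^\s;]+)")


def parse_pins(text):
    """Return {package_name: version} for every exactly pinned line."""
    pins = {}
    for line in text.splitlines():
        match = PIN.match(line)
        if match:
            pins[match.group(1).lower()] = match.group(2)
    return pins


def same_pin(pins, first, second):
    """True only if both packages are pinned and to the same version."""
    return first in pins and pins[first] == pins.get(second)
```

Calling `same_pin(parse_pins(REQUIREMENTS), "jax", "jaxlib")` on this input reports whether the two pins agree. Requiring identical versions mirrors this PR's pins; it is stricter than JAX's general jax/jaxlib compatibility rule.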

The "Dependencies & Setup" section of the design document (.exp/design-workflow-1-grok-1-inference-and-sampling.md) was updated to document this change. The existing Mermaid diagrams (Initialization Sequence and Inference Sequence) do not include a setup step, so they were left unchanged. The diagram below is an illustrative diff view of the Initialization Sequence, with the added setup step highlighted in green:

```mermaid
flowchart TD
    Setup["Environment Setup<br/>pip install -r requirements.txt<br/><br/>Added: jaxlib==0.4.25 pin<br/>for JAX compatibility"]:::added
    User[User]
    RunPy[run.py]
    IR[InferenceRunner]
    MR[ModelRunner]
    Model[model.py]
    Checkpoint[checkpoint.py]
    JAX[JAX Runtime]

    Setup --> User --> RunPy --> IR --> MR --> Model
    MR -.->|unchanged| JAX
    MR --> Checkpoint --> MR
    IR --> IR
    RunPy --> IR

    classDef added fill:#90EE90,stroke:#333,stroke-width:4px;
    classDef unchanged fill:#ffff00,stroke:#333,stroke-width:2px
    JAX:::unchanged
```

## Model Loading and Initialization Analysis
### Summary of design changes
As with Workflow 1, the PR ensures JAX/jaxlib compatibility for the core operations of model architecture definition, parameter sharding, and checkpoint loading, which prevents errors during mesh creation and pjit transformations. The change is to prerequisites only; it does not alter any sequence or component in the design. Benefits: enables smooth distributed initialization across devices/hosts. The design doc (.exp/design-workflow-2-model-loading-and-initialization.md) does not list dependencies explicitly, although it relies on them, so its text was not updated.
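
A quick way to confirm the prerequisite before mesh creation or pjit compilation is to compare the installed versions. This sketch uses only the standard library's `importlib.metadata`; the helper name `find_version_mismatch` is hypothetical, and the exact-equality rule mirrors the pins in this PR rather than JAX's general compatibility rule:

```python
from importlib import metadata


def find_version_mismatch(packages, get_version=metadata.version):
    """Return {name: version} if the packages disagree on version, else {}.

    A package that is not installed is recorded with version None.
    `get_version` is injectable so the check can be exercised without JAX.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = get_version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    if len(set(versions.values())) > 1:
        return versions
    return {}
```

Calling `find_version_mismatch(["jax", "jaxlib"])` in the target environment returns an empty dict when both distributions report the same version (or when neither is installed), and the differing versions otherwise.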

Diff view of the Sequence Diagram with the setup step added:

```mermaid
flowchart TD
    Setup["Environment Setup<br/>pip install -r requirements.txt<br/><br/>Added: jaxlib==0.4.25"]:::added
    S["Script/User"]
    MR["ModelRunner"]
    MD["Model model.py"]
    CL["Checkpoint checkpoint.py"]
    JM["JAX Mesh"]

    Setup --> S --> MR --> MD
    MR --> JM
    JM --> MR
    MR --> MR
    MR --> CL
    CL --> MR

    classDef added fill:#90EE90,stroke:#333,stroke-width:4px;
    classDef changed fill:#ffff00,stroke:#333,stroke-width:2px;
    JM:::changed
```

## Model Forward Pass and Logits Computation Analysis
### Summary of design changes
The PR's dependency fix supports the JAX-based forward pass through the transformer layers, including attention and MoE computations, by ensuring library imports succeed. It does not change the computation flow, inputs/outputs, or sharding. Potential benefit: a stable environment for evaluation or custom loops using logits_fn. The design doc (.exp/design-workflow-3-model-forward-pass-and-logits-computation.md) has no dependencies section and was not edited directly; the change is documented through the project overview instead.
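
Because the failure mode described here surfaces at import time, a small preflight check can report it before any forward pass is attempted. The sketch below is illustrative only; `preflight_imports` is a hypothetical helper, not something the repository provides:

```python
import importlib


def preflight_imports(module_names):
    """Try each import and return {module_name: error text} for the failures."""
    failures = {}
    for name in module_names:
        try:
            importlib.import_module(name)
        except Exception as exc:
            # Catch broadly: the analysis above mentions RuntimeError,
            # so the failure is not necessarily an ImportError.
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures
```

For this project, the list might be `["jax", "jaxlib", "haiku", "numpy", "sentencepiece"]` (`dm_haiku` is imported as `haiku`); an empty result means every import succeeded.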

To illustrate, a diff of a typical forward pass sequence would add the setup step in the same way (abbreviated; refer to the original doc for the full sequence):

```mermaid
flowchart TD
    Setup["Environment Setup<br/>Ensure deps installed<br/>jaxlib fix"]:::added
    User[User]
    Runner["ModelRunner"]
    Model["LanguageModel"]
    JAX["JAX Engine"]

    Setup --> User --> Runner --> Model --> JAX
    JAX --> Model --> Runner

    classDef added fill:#90EE90,stroke:#333,stroke-width:4px;
    classDef changed fill:#ffff00,stroke:#333,stroke-width:2px;
    JAX:::changed
```

Note: For Workflow 3, the original diagrams show detailed forward and logits sequences; the addition is a pre-execution step.
