74 changes: 32 additions & 42 deletions .exp/design-workflow-1-grok-1-inference-and-sampling.md
@@ -6,7 +6,7 @@ The "Grok-1 Inference and Sampling" workflow provides the machinery to load th

Key inputs: Checkpoint in `./checkpoints/ckpt-0/`, `tokenizer.model`, GPU cluster, prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
Outputs: Generated text strings.
-Entry points: `run.py` for test run, or `InferenceRunner().run()` generator for streaming requests.
+Entry points: `run.py` for secure automated test run with encryption/logging, or `InferenceRunner.predict(Request)` for single-request autoregressive generation.
Relevant files: `run.py`, `runners.py`, `model.py`, `checkpoint.py`, `tokenizer.model`.

The workflow orchestrates model loading, compilation of sharded compute functions, prompt processing (prefill KV cache while sampling first token), and iterative single-token generation using cached attention keys/values, until max length or EOS.
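
For orientation, a minimal usage sketch of the new entry point, assuming the `Request` fields listed above and the `predict` method named in the entry points; constructor arguments are elided, not the repo's exact API:

```python
# Hypothetical usage of the simplified single-request entry point.
# Field names follow the Request description above; the exact
# InferenceRunner constructor arguments are elided.
from runners import InferenceRunner, Request

runner = InferenceRunner(...)  # paths, meshes, tokenizer config, etc.
runner.initialize()            # load checkpoint and tokenizer (assumed)

request = Request(
    prompt="The answer to life, the universe, and everything is ",
    temperature=0.01,  # near-greedy
    nucleus_p=0.95,    # keep top-p probability mass each step
    rng_seed=42,
    max_len=256,
)
print(runner.predict(request))
```
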
@@ -52,56 +52,46 @@ sequenceDiagram
participant Model as model.py
participant Checkpoint as checkpoint.py
participant JAX as JAX Runtime
-User->>RunPy: Execute main()
-RunPy->>IR: Create with config, MR, paths, meshes
+participant Sec as Security (run.py)
+User->>RunPy: Execute main() with env vars, logging, key mgmt
+RunPy->>IR: initialize_inference_runner(config, paths from env, meshes)
IR->>MR: initialize(dummy_data, meshes)
MR->>Model: model.initialize(), fprop_dtype=bf16
-Note over MR,JAX: Calculate batch sizes, create mesh (data, model axes)
-MR->>MR: hk.transform forward/logits_fn with pjit sharding
-MR->>Checkpoint: load_or_init -> restore(shapes, mesh, sharding)
-Checkpoint->>MR: Sharded params (TrainingState)
-IR->>IR: Load tokenizer, compile pjit funcs (sample_step, prefill_memory, new_memory) with shardings
-IR->>IR: Precompile with dummy prompts for pad_sizes
-RunPy->>IR: gen = run() // generator setup with initial memory, settings, etc.
+Note over MR,JAX: Calculate batch sizes, create mesh
+MR->>MR: hk.transform forward/logits_fn/eval_forward with pjit
+MR->>Checkpoint: load_or_init(True) -> restore from load path
+Checkpoint->>MR: Sharded params
+IR->>IR: set self.params/vocab_size, load tokenizer, text_to_token_ids
+Note over IR: Simplified, no sampling pjit/precompile
+RunPy->>IR: generate_text -> predict(Request)
+RunPy->>Sec: encrypt_message & decrypt_message output
```
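
Read as code, the updated setup implies roughly the following order of operations; method names come from the diagram, while signatures and attribute names are assumptions:

```python
# Sketch of the initialization order implied by the diagram above.
# Signatures are assumptions; only the method names and their ordering
# are taken from the diagram.
import sentencepiece as spm

class InferenceRunner:
    def initialize(self):
        # Build meshes, transform forward/logits/eval_forward fns,
        # and compute shardings.
        self.runner.initialize(self.dummy_data, self.local_mesh_config,
                               self.between_hosts_config)
        # Restore sharded params from the checkpoint load path.
        state = self.runner.load_or_init(self.dummy_data,
                                         from_checkpoint=True)
        self.params = state.params
        # Load the SentencePiece tokenizer used by text_to_token_ids.
        self.tokenizer = spm.SentencePieceProcessor(
            model_file=self.tokenizer_path)
        self.vocab_size = self.tokenizer.vocab_size()
```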

## Inference and Sampling Sequence

```mermaid
sequenceDiagram
-participant Gen as Generator (run())
-participant Req as Request
-participant Tok as Tokenizer
-participant Prefill as prefill_memory
-participant Step as sample_step
-participant LM as LM forward
-participant Samp as sample_token
-participant Mem as KV Memory
-participant Out as Output
-
-Note over Gen: Initial setup: memory, rngs, settings, last_output
-
-Gen->>Req: yield (wait for input)
-Req->>Gen: send Request(prompt, temp, p, seed, max_len)
-Gen->>Tok: encode(prompt) -> tokens
-Gen->>Gen: pad tokens, create settings, active=1
-Gen->>Prefill: call prefill_memory(tokens, len, new_settings, slot)
-Prefill->>LM: hk_forward(tokens, new_mem, length, active) // process prompt
-LM->>Samp: sample_token from logits // sample first token?
-Prefill->>Mem: update KV cache with prompt tokens + first?
-Prefill->>Gen: updated rngs, last_output, memory, settings
-loop Autoregressive Sampling (while active and < max_len)
-Gen->>Step: sample_step(params, rngs, last_output, memory, settings)
-Step->>LM: hk_forward(last_token, memory) // decode step
-LM->>Samp: sample_token(logits, settings)
-Step->>Mem: update memory with new KV (donate old)
-Step->>Gen: new rngs, sample_output, memory
-Gen->>Gen: append token to sequence, copy to host
-alt Reached max_len or EOS?
-Gen->>Out: decode all tokens -> yield text
-Gen->>Gen: deactivate slot, free for new req
-end
+participant RunPy
+participant IR
+participant Tok
+participant Forward
+participant Samp
+participant Sec
+RunPy->>IR: generate_text -> predict(Request)
+IR->>Tok: text_to_token_ids(prompt) -> token_ids
+Note over IR: settings: temp, nucleus_p, vocab mask, active
+loop up to max_len (stop on eos/pad)
+IR->>IR: split step_rng from gen_rng
+IR->>Forward: eval_forward(token_ids) -> lm_outputs
+Forward->>IR: logits
+IR->>Samp: sample_token(step_rng, lm_outputs, settings) -> new_token
+Samp->>IR: new_token
+IR->>IR: token_ids = concat(token_ids, new_token)
+end
+IR->>IR: tokenizer.decode(token_ids.squeeze()) -> text
+IR->>RunPy: text
+RunPy->>Sec: encrypt_message(text) -> decrypt -> print
+Note over IR: Simplified single loop, no KV cache/batching (change); fixes for correctness
```
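
In code, the simplified loop might look like the sketch below; `eval_forward`, `sample_token`, and the settings container stand in for the names in `runners.py`, and all signatures are assumptions:

```python
# Sketch of the cache-free autoregressive loop shown above. Every step
# re-runs the full forward pass over the whole sequence (no KV cache).
import jax
import jax.numpy as jnp

def predict(params, request, eval_forward, sample_token, tokenizer,
            eos_id, pad_id):
    token_ids = jnp.asarray([tokenizer.encode(request.prompt)])  # [1, T]
    gen_rng = jax.random.PRNGKey(request.rng_seed)
    settings = {"temperature": request.temperature,
                "nucleus_p": request.nucleus_p}
    for _ in range(request.max_len):
        gen_rng, step_rng = jax.random.split(gen_rng)
        lm_outputs = eval_forward(params, token_ids)              # logits
        new_token = sample_token(step_rng, lm_outputs, settings)  # [1, 1]
        token_ids = jnp.concatenate([token_ids, new_token], axis=-1)
        if int(new_token[0, 0]) in (eos_id, pad_id):              # stop cond.
            break
    return tokenizer.decode(token_ids.squeeze().tolist())
```

As the analysis below notes, re-running the full forward pass each step is quadratic in sequence length compared with the cached decode path it replaces.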

## Sharding and Distributed Execution
47 changes: 17 additions & 30 deletions .exp/design-workflow-2-model-loading-and-initialization.md
@@ -46,40 +46,27 @@ The process ensures efficient loading of 314B parameters, correct mapping betwee

```mermaid
sequenceDiagram
-participant S as Script/User
+participant IR as InferenceRunner (caller)
participant MR as ModelRunner
-participant MD as Model (model.py)
-participant CL as Checkpoint (checkpoint.py)
+participant MD as model.py
+participant CL as checkpoint.py
participant JM as JAX Mesh
participant D as Devices

-S->>+MR: new ModelRunner(config)
-MR->>+MD: model.make(mesh) [in init]
-Note right of MR: initialize(local_mesh_config, between_hosts_config, init_data)
-MR->>+JM: make_mesh(configs)
-JM-->>-MR: mesh
-MR->>+MR: hk.transform(forward) & pjit
-MR->>+MR: compute state_sharding via eval_shape & partition_rules

-alt Load from Checkpoint
-MR->>+MR: load_or_init(init_data, from_checkpoint=True)
-MR->>+MR: eval_shape(init_fn) -> shapes
-MR->>+CL: restore(path, shapes, mesh, sharding, params_only=True)
-Note right of CL: load_tensors(): parallel unpickle sharded tensors<br/>from ckpt-0/tensorXXXX_YYY
-CL->>+JM: host_local_to_global_array(state, mesh, sharding)
-JM->>+D: Shard params across devices/hosts
-D-->>-JM:
-JM-->>-CL: Sharded state
-CL-->>-MR: params
-else Random Init
-MR->>+MR: load_or_init(init_data, from_checkpoint=False)
-MR->>+MR: init_fn(rng, init_data) -> forward.init(rng, inputs)
-Note right of MR: Generates random params matching shapes
-MR->>+JM: Shard new params
-JM-->>-MR: Sharded params
+IR->>MR: initialize(dummy_data, local/between configs) [new call]
+Note right of MR: Compute batch sizes, make mesh, transform forward fns, sharding
+IR->>MR: load_or_init(dummy_data, from_checkpoint=True) [updated param]
+alt Load from Checkpoint (default now)
+MR->>MR: eval_shape -> shapes
+MR->>CL: restore(path from load or default, shapes, mesh, sharding)
+Note right of CL: parallel load, shard tensors
+CL->>JM: to global sharded arrays
+JM->>D: distribute
+CL-->>MR: params
+else Random Init (if no path)
+MR->>MR: init_fn(rng, dummy_data) [random]
+MR->>JM: shard
end

-MR-->>-S: Sharded TrainingState(params)
+MR-->>IR: state.params -> set self.params
```
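
A rough sketch of the `load_or_init` branching shown above; `restore` and `init_fn` are named in the diagram, but their exact signatures are assumed:

```python
# Sketch of load_or_init: restore sharded params from the checkpoint
# when a path is available, otherwise fall back to random init.
import jax
from checkpoint import restore  # signature assumed

def load_or_init(self, init_data, from_checkpoint=True):
    rng = jax.random.PRNGKey(0)
    if from_checkpoint and self.checkpoint_path:
        # eval_shape computes param shapes without real compute; restore
        # then unpickles tensors in parallel and shards them on the mesh.
        shapes = jax.eval_shape(self.init_fn, rng, init_data)
        return restore(self.checkpoint_path, shapes,
                       self.mesh, self.state_sharding)
    # Random init producing params with the same shapes/sharding.
    return self.init_fn(rng, init_data)
```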

## Additional Design Aspects
Binary file added __pycache__/run.cpython-313.pyc
Binary file not shown.
Binary file added __pycache__/runners.cpython-313.pyc
Binary file not shown.
69 changes: 69 additions & 0 deletions pr-analysis-352.md
@@ -0,0 +1,69 @@
# PR #352: Workflow Design Impact Analysis

## Affected Workflows
- **Grok-1 Inference and Sampling**: Changed files `run.py` and `runners.py` are central to this workflow's entry points and implementation. The PR introduces automation, security, and simplified inference, matching the description of automated model setup, inference, sampling techniques, and compliance features. Justification: Direct modifications to core code, added features like encryption and logging. [PR #352](https://github.com/xai-org/grok-1/pull/352)
- **Model Loading and Initialization**: Modifications in `runners.py` alter how `ModelRunner.load_or_init` is invoked and set up: `initialize` is now properly called before loading from the checkpoint, with distributed support. This touches the workflow's entry point and its relevant files.

## Workflow 1 Analysis
### Summary of design changes
The PR enhances the workflow with security and automation while simplifying the inference mechanism. Key changes:
- **run.py**: Added env var configuration for paths and the encryption key, Fernet-based encryption/decryption of the output, and comprehensive logging and error handling in modular functions; this implements secure data handling and monitoring for DoD/NSA compliance (see the sketch after this list).
- **runners.py**: Simplified `InferenceRunner`: setup now calls `ModelRunner.initialize` and loads from the checkpoint, and a new `predict(Request)` method performs single-request generation via an autoregressive sampling loop (fixed RNG handling, settings, and stop conditions). The batched generator logic is removed.
- Affected aspects: added security steps, changed inference to a simple non-cached loop, restored checkpoint loading.
- How implemented: env vars, Fernet, and a fixed predict loop using `eval_forward` and `sample_token`.
- Benefits: compliance-ready, automated, and simple for single-request use.
- Implications: less efficient for long or batched generation, but fixed for correctness and for loading the real model.
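
A minimal sketch of the Fernet round-trip described above. The env var name `GROK_ENCRYPTION_KEY` and the fallback behavior are assumptions; `encrypt_message`/`decrypt_message` follow the sequence diagrams, and only the use of `cryptography.fernet` is taken from the PR description:

```python
# Sketch of env-var-keyed Fernet encryption/decryption with logging.
import os
import logging
from cryptography.fernet import Fernet

logging.basicConfig(level=logging.INFO)

def get_fernet() -> Fernet:
    key = os.environ.get("GROK_ENCRYPTION_KEY")  # assumed env var name
    if key is None:
        key = Fernet.generate_key()  # ephemeral fallback key
        logging.warning("GROK_ENCRYPTION_KEY not set; generated one")
    return Fernet(key)

def encrypt_message(fernet: Fernet, text: str) -> bytes:
    return fernet.encrypt(text.encode("utf-8"))

def decrypt_message(fernet: Fernet, token: bytes) -> str:
    return fernet.decrypt(token).decode("utf-8")

fernet = get_fernet()
ciphertext = encrypt_message(fernet, "generated text")
logging.info("decrypted output: %s", decrypt_message(fernet, ciphertext))
```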

```mermaid
sequenceDiagram
participant User
participant RunPy as run.py
participant IR as InferenceRunner
participant MR as ModelRunner
participant CL as checkpoint.py
participant JAX as JAX Runtime
participant Sec as Security Features

rect rgb(220, 252, 220)
Note over User,Sec: Addition
User->>RunPy: main() env vars, logging
RunPy->>Sec: encrypt/decrypt output
Sec->>User: secure print
end
rect rgb(255, 250, 205)
Note over RunPy,IR: Change
RunPy->>IR: create, initialize
IR->>MR: initialize + load_or_init (load ckpt)
MR->>CL: restore sharded
IR->>IR: predict loop (full forward each step, no cache)
end
rect rgb(255, 224, 224)
Note over IR: Removal: batch gen, pjit sampling funcs removed
end
```
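
The fixed predict loop's `sample_token` presumably combines temperature scaling with nucleus (top-p) filtering; a hedged sketch of such a sampler, not the repo's actual implementation:

```python
# Illustrative temperature + nucleus (top-p) sampling step.
import jax
import jax.numpy as jnp

def sample_token(rng, logits, temperature=1.0, nucleus_p=0.95):
    # Temperature scaling, then keep the smallest prefix of the sorted
    # distribution whose cumulative mass reaches nucleus_p.
    logits = logits / jnp.maximum(temperature, 1e-6)
    probs = jax.nn.softmax(logits, axis=-1)
    sorted_idx = jnp.argsort(-probs, axis=-1)
    sorted_probs = jnp.take_along_axis(probs, sorted_idx, axis=-1)
    cumulative = jnp.cumsum(sorted_probs, axis=-1)
    keep = cumulative - sorted_probs < nucleus_p  # top token always kept
    filtered = jnp.where(keep, sorted_probs, 0.0)
    filtered = filtered / filtered.sum(axis=-1, keepdims=True)
    choice = jax.random.categorical(rng, jnp.log(filtered + 1e-9))
    return jnp.take_along_axis(sorted_idx, choice[..., None], axis=-1)
```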

## Workflow 2 Analysis
### Summary of design changes
Updated the invocation in `InferenceRunner.initialize()` to call `ModelRunner.initialize()` before `load_or_init` with `from_checkpoint=True`, ensuring full distributed sharding and checkpoint loading, plus minor updates to types and logging. See the sketch after this list.
- Aspects affected: invocation sequence; checkpoint loading is now actually enabled.
- Implementation: `checkpoint_path` is set from the load parameter, with proper setup beforehand.
- Benefits: preserves distributed loading while adding automation.
- Implications: fixes the originally broken flow and aligns with the automated setup.
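
Concretely, the corrected call order (placeholder arguments, not the repo's exact API):

```python
# The fix: ModelRunner.initialize must run before load_or_init so the
# mesh, transformed fns, and shardings exist when restore() is called.
runner = ModelRunner(model=model_config, checkpoint_path=checkpoint_path)
runner.initialize(dummy_data, local_mesh_config, between_hosts_config)
state = runner.load_or_init(dummy_data, from_checkpoint=True)
params = state.params
```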

```mermaid
sequenceDiagram
participant IR
participant MR
participant JM as Mesh
participant CL

rect rgb(220, 252, 220)
Note over IR,MR: Addition/Fix
IR->>MR: initialize(dummy, meshes)
end
rect rgb(255, 250, 205)
Note over IR,CL: Change
IR->>MR: load_or_init(..., True) load path
MR->>JM: sharding setup
MR->>CL: restore
end
rect rgb(255, 224, 224)
Note over IR: Removal (original): no longer skips setup
end
```