diff --git a/.exp/design-workflow-1-grok-1-inference-and-sampling.md b/.exp/design-workflow-1-grok-1-inference-and-sampling.md
index 31cae7b..df4a3ac 100644
--- a/.exp/design-workflow-1-grok-1-inference-and-sampling.md
+++ b/.exp/design-workflow-1-grok-1-inference-and-sampling.md
@@ -6,7 +6,7 @@ The "Grok-1 Inference and Sampling" workflow provides the machinery to load th
 Key inputs: Checkpoint in `./checkpoints/ckpt-0/`, `tokenizer.model`, GPU cluster, prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
 Outputs: Generated text strings.
-Entry points: `run.py` for test run, or `InferenceRunner().run()` generator for streaming requests.
+Entry points: `run.py` for secure automated test run with encryption/logging, or `InferenceRunner.predict(Request)` for single-request autoregressive generation.
 Relevant files: `run.py`, `runners.py`, `model.py`, `checkpoint.py`, `tokenizer.model`.
 The workflow orchestrates model loading, compilation of sharded compute functions, prompt processing (prefill KV cache while sampling first token), and iterative single-token generation using cached attention keys/values, until max length or EOS.
@@ -52,56 +52,46 @@ sequenceDiagram
     participant Model as model.py
     participant Checkpoint as checkpoint.py
     participant JAX as JAX Runtime
-    User->>RunPy: Execute main()
-    RunPy->>IR: Create with config, MR, paths, meshes
+    participant Sec as Security (run.py)
+    User->>RunPy: Execute main() with env vars, logging, key mgmt
+    RunPy->>IR: initialize_inference_runner(config, paths from env, meshes)
     IR->>MR: initialize(dummy_data, meshes)
     MR->>Model: model.initialize(), fprop_dtype=bf16
-    Note over MR,JAX: Calculate batch sizes, create mesh (data, model axes)
-    MR->>MR: hk.transform forward/logits_fn with pjit sharding
-    MR->>Checkpoint: load_or_init -> restore(shapes, mesh, sharding)
-    Checkpoint->>MR: Sharded params (TrainingState)
-    IR->>IR: Load tokenizer, compile pjit funcs (sample_step, prefill_memory, new_memory) with shardings
-    IR->>IR: Precompile with dummy prompts for pad_sizes
-    RunPy->>IR: gen = run() // generator setup with initial memory, settings, etc.
+    Note over MR,JAX: Calculate batch sizes, create mesh
+    MR->>MR: hk.transform forward/logits_fn/eval_forward with pjit
+    MR->>Checkpoint: load_or_init(True) -> restore from load path
+    Checkpoint->>MR: Sharded params
+    IR->>IR: set self.params/vocab_size, load tokenizer, text_to_token_ids
+    Note over IR: Simplified, no sampling pjit/precompile
+    RunPy->>IR: generate_text -> predict(Request)
+    RunPy->>Sec: encrypt_message & decrypt_message output
 ```

 ## Inference and Sampling Sequence

 ```mermaid
 sequenceDiagram
-    participant Gen as Generator (run())
-    participant Req as Request
-    participant Tok as Tokenizer
-    participant Prefill as prefill_memory
-    participant Step as sample_step
-    participant LM as LM forward
-    participant Samp as sample_token
-    participant Mem as KV Memory
-    participant Out as Output
-
-    Note over Gen: Initial setup: memory, rngs, settings, last_output
-
-    Gen->>Req: yield (wait for input)
-    Req->>Gen: send Request(prompt, temp, p, seed, max_len)
-    Gen->>Tok: encode(prompt) -> tokens
-    Gen->>Gen: pad tokens, create settings, active=1
-    Gen->>Prefill: call prefill_memory(tokens, len, new_settings, slot)
-    Prefill->>LM: hk_forward(tokens, new_mem, length, active) // process prompt
-    LM->>Samp: sample_token from logits // sample first token?
-    Prefill->>Mem: update KV cache with prompt tokens + first?
-    Prefill->>Gen: updated rngs, last_output, memory, settings
-    loop Autoregressive Sampling (while active and < max_len)
-        Gen->>Step: sample_step(params, rngs, last_output, memory, settings)
-        Step->>LM: hk_forward(last_token, memory) // decode step
-        LM->>Samp: sample_token(logits, settings)
-        Step->>Mem: update memory with new KV (donate old)
-        Step->>Gen: new rngs, sample_output, memory
-        Gen->>Gen: append token to sequence, copy to host
-        alt Reached max_len or EOS?
-            Gen->>Out: decode all tokens -> yield text
-            Gen->>Gen: deactivate slot, free for new req
-        end
+    participant RunPy
+    participant IR
+    participant Tok
+    participant Forward
+    participant Samp
+    participant Sec
+    RunPy->>IR: generate_text -> predict(Request)
+    IR->>Tok: text_to_token_ids(prompt) -> token_ids
+    Note over IR: settings: temp, nucleus_p, vocab mask, active
+    loop up to max_len (stop on eos/pad)
+        IR->>IR: split step_rng from gen_rng
+        IR->>Forward: eval_forward(token_ids) -> lm_outputs
+        Forward->>IR: logits
+        IR->>Samp: sample_token(step_rng, lm_outputs, settings) -> new_token
+        Samp->>IR: new_token
+        IR->>IR: token_ids = concat(token_ids, new_token)
     end
+    IR->>IR: tokenizer.decode(token_ids.squeeze()) -> text
+    IR->>RunPy: text
+    RunPy->>Sec: encrypt_message(text) -> decrypt -> print
+    Note over IR: Simplified to a single loop (changed): no KV cache or batching; includes correctness fixes
 ```

 ## Sharding and Distributed Execution
diff --git a/.exp/design-workflow-2-model-loading-and-initialization.md b/.exp/design-workflow-2-model-loading-and-initialization.md
index 5f904ef..0d9a228 100644
--- a/.exp/design-workflow-2-model-loading-and-initialization.md
+++ b/.exp/design-workflow-2-model-loading-and-initialization.md
@@ -46,40 +46,27 @@ The process ensures efficient loading of 314B parameters, correct mapping betwee
 ```mermaid
 sequenceDiagram
-    participant S as Script/User
+    participant IR as InferenceRunner (caller)
     participant MR as ModelRunner
-    participant MD as Model (model.py)
-    participant CL as Checkpoint (checkpoint.py)
+    participant MD as model.py
+    participant CL as checkpoint.py
     participant JM as JAX Mesh
     participant D as Devices
-
-    S->>+MR: new ModelRunner(config)
-    MR->>+MD: model.make(mesh) [in init]
-    Note right of MR: initialize(local_mesh_config, between_hosts_config, init_data)
-    MR->>+JM: make_mesh(configs)
-    JM-->>-MR: mesh
-    MR->>+MR: hk.transform(forward) & pjit
-    MR->>+MR: compute state_sharding via eval_shape & partition_rules
-
-    alt Load from Checkpoint
-        MR->>+MR: load_or_init(init_data, from_checkpoint=True)
-        MR->>+MR: eval_shape(init_fn) -> shapes
-        MR->>+CL: restore(path, shapes, mesh, sharding, params_only=True)
-        Note right of CL: load_tensors(): parallel unpickle sharded tensors
        from ckpt-0/tensorXXXX_YYY
-        CL->>+JM: host_local_to_global_array(state, mesh, sharding)
-        JM->>+D: Shard params across devices/hosts
-        D-->>-JM: 
-        JM-->>-CL: Sharded state
-        CL-->>-MR: params
-    else Random Init
-        MR->>+MR: load_or_init(init_data, from_checkpoint=False)
-        MR->>+MR: init_fn(rng, init_data) -> forward.init(rng, inputs)
-        Note right of MR: Generates random params matching shapes
-        MR->>+JM: Shard new params
-        JM-->>-MR: Sharded params
+    IR->>MR: initialize(dummy_data, local/between configs) [new call]
+    Note right of MR: Compute batch sizes, make mesh, transform forward fns, sharding
+    IR->>MR: load_or_init(dummy_data, from_checkpoint=True) [updated param]
+    alt Load from Checkpoint (default now)
+        MR->>MR: eval_shape -> shapes
+        MR->>CL: restore(path from load or default, shapes, mesh, sharding)
+        Note right of CL: parallel load, shard tensors
+        CL->>JM: to global sharded arrays
+        JM->>D: distribute
+        CL-->>MR: params
+    else Random Init (if no path)
+        MR->>MR: init_fn(rng, dummy_data) [random]
+        MR->>JM: shard
     end
-
-    MR-->>-S: Sharded TrainingState(params)
+    MR-->>IR: state.params -> set self.params
 ```

 ## Additional Design Aspects
diff --git a/__pycache__/run.cpython-313.pyc b/__pycache__/run.cpython-313.pyc
new file mode 100644
index 0000000..5059ab1
Binary files /dev/null and b/__pycache__/run.cpython-313.pyc differ
diff --git a/__pycache__/runners.cpython-313.pyc b/__pycache__/runners.cpython-313.pyc
new file mode 100644
index 0000000..6bc6f60
Binary files /dev/null and b/__pycache__/runners.cpython-313.pyc differ
diff --git a/pr-analysis-352.md b/pr-analysis-352.md
new file mode 100644
index 0000000..57357a3
--- /dev/null
+++ b/pr-analysis-352.md
@@ -0,0 +1,69 @@
+# PR #352: Workflow Design Impact Analysis
+
+## Affected Workflows
+- **Grok-1 Inference and Sampling**: Changed files `run.py` and `runners.py` are central to this workflow's entry points and implementation. The PR introduces automation, security features, and a simplified inference path, matching the workflow's description of automated model setup, inference, sampling techniques, and compliance features. Justification: direct modifications to core code, plus added features such as encryption and logging. [PR #352](https://github.com/xai-org/grok-1/pull/352)
+- **Model Loading and Initialization**: Modifications in `runners.py` alter how `ModelRunner.load_or_init` is invoked and set up; it now properly calls initialize before loading from the checkpoint with distributed support. This touches the workflow's entry point and relevant files.
+
+## Workflow 1 Analysis
+### Summary of design changes
+The PR enhances the workflow with security and automation while simplifying the inference mechanism. Key changes:
+- **run.py**: Added env var configuration for paths and the encryption key, Fernet-based encryption/decryption of output, and comprehensive logging and error handling in modular functions. This implements secure data handling and monitoring for DoD/NSA compliance.
+- **runners.py**: Simplified `InferenceRunner`: setup now calls `ModelRunner.initialize` and loads from the checkpoint, and a new `predict(Request)` method performs single-request generation via an autoregressive sampling loop (with fixes for RNG handling, settings, and stop conditions). The batched generator logic was removed.
+- Affected aspects: Added security steps, changed inference to a simple non-cached loop, restored checkpoint loading.
+- How implemented: Env vars, Fernet, and a corrected predict loop built on eval_forward and sample_token (illustrated in the sketch after this diff).
+- Benefits: Compliance-ready, automated, simple for single-request use.
+- Implications: Less efficient for long or batched generation, but correct and able to load the real model.
+
+```mermaid
+sequenceDiagram
+    participant User
+    participant RunPy as run.py
+    participant IR as InferenceRunner
+    participant MR as ModelRunner
+    participant CL as checkpoint.py
+    participant JAX as JAX Runtime
+    participant Sec as Security Features
+
+    rect green Addition
+        User->>RunPy: main() env vars, logging
+        RunPy->>Sec: encrypt/decrypt output
+        Sec->>User: secure print
+    end
+    rect yellow Change
+        RunPy->>IR: create, initialize
+        IR->>MR: initialize + load_or_init (load ckpt)
+        MR->>CL: restore sharded
+        IR->>IR: predict loop (full forward each step, no cache)
+    end
+    rect red Removal
+        Note over IR: Removed batch gen, pjit sampling funcs
+    end
+```
+
+## Workflow 2 Analysis
+### Summary of design changes
+Updated invocation in `InferenceRunner.initialize()` to call `ModelRunner.initialize()` before `load_or_init` with `from_checkpoint=True`, ensuring full distributed sharding and checkpoint loading. Minor updates to types and logging.
+- Aspects affected: Invocation sequence; checkpoint loading is now enabled.
+- Implementation: checkpoint_path is set from the load parameter, after proper setup.
+- Benefits: Preserves distributed loading alongside the automation.
+- Implications: Fixes the originally broken flow and aligns with the automated setup.
+
+```mermaid
+sequenceDiagram
+    participant IR
+    participant MR
+    participant JM as Mesh
+    participant CL
+
+    rect green Addition/Fix
+        IR->>MR: initialize(dummy, meshes)
+    end
+    rect yellow Change
+        IR->>MR: load_or_init(..., True) load path
+        MR->>JM: sharding setup
+        MR->>CL: restore
+    end
+    rect red Removal (original)
+        Note over IR: No longer skips setup
+    end
+```
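
The following is a minimal, self-contained Python sketch of the simplified inference path described above, not code from the PR. It shows a toy `Request`, an autoregressive `predict` loop that reruns a full forward pass each step (no KV cache), temperature plus nucleus sampling, and a Fernet encrypt/decrypt round trip on the output. The stub `eval_forward`, the byte-level tokenizer, `EOS_ID`, `VOCAB_SIZE`, and the key handling are illustrative stand-ins; in the PR the key is reportedly read from an environment variable and the real model, tokenizer, and `sample_token` live in `runners.py` and `model.py`.

```python
# Illustrative sketch only (not the PR's code): mimics the simplified predict()
# loop and the Fernet-wrapped output described above, with the model, tokenizer,
# and settings replaced by toy stand-ins.
import numpy as np
from dataclasses import dataclass
from cryptography.fernet import Fernet

VOCAB_SIZE = 128   # toy vocabulary; the real model uses the tokenizer's vocab size
EOS_ID = 2         # hypothetical end-of-sequence id


@dataclass
class Request:
    prompt: str
    temperature: float = 0.7
    nucleus_p: float = 0.95
    rng_seed: int = 0
    max_len: int = 32


def eval_forward(token_ids: np.ndarray) -> np.ndarray:
    """Stand-in for the model's forward pass: logits for the next token."""
    rng = np.random.default_rng(int(token_ids.sum()) % (2**32))
    return rng.standard_normal(VOCAB_SIZE)


def sample_token(rng: np.random.Generator, logits: np.ndarray,
                 temperature: float, nucleus_p: float) -> int:
    """Temperature + nucleus (top-p) sampling over a single logits vector."""
    scaled = (logits - logits.max()) / max(temperature, 1e-6)
    probs = np.exp(scaled)
    probs /= probs.sum()
    order = np.argsort(-probs)                     # most likely token first
    keep = np.cumsum(probs[order]) <= nucleus_p    # truncate the low-probability tail
    keep[0] = True                                 # always keep the top token
    kept = order[keep]
    p = probs[kept] / probs[kept].sum()
    return int(rng.choice(kept, p=p))


def predict(request: Request) -> str:
    # Stand-in for text_to_token_ids: raw prompt bytes as token ids.
    token_ids = np.frombuffer(request.prompt.encode(), dtype=np.uint8).astype(np.int64)
    gen_rng = np.random.default_rng(request.rng_seed)
    for _ in range(request.max_len):               # full forward pass each step, no KV cache
        logits = eval_forward(token_ids)
        new_token = sample_token(gen_rng, logits, request.temperature, request.nucleus_p)
        token_ids = np.concatenate([token_ids, [new_token]])
        if new_token == EOS_ID:                    # stop on end-of-sequence
            break
    return " ".join(str(t) for t in token_ids)     # stand-in for tokenizer.decode


if __name__ == "__main__":
    # The PR reportedly sources the Fernet key from an environment variable;
    # here a fresh key is generated so the sketch runs on its own.
    fernet = Fernet(Fernet.generate_key())
    text = predict(Request(prompt="The answer to life, the universe, and everything is"))
    ciphertext = fernet.encrypt(text.encode())     # encrypt_message analogue
    print(fernet.decrypt(ciphertext).decode())     # decrypt_message analogue
```

Rerunning the full forward pass on the whole sequence each step keeps the loop simple but costs quadratic work over the generated length, which is the efficiency trade-off the analysis notes for long or batched generation.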